@@ -178,10 +178,6 @@ struct Sequence {
178178 }
179179 block_ptr = block.parent_idx ;
180180 }
181- CHECK_LE (depth, kPagedKVCacheMaxBlockDepth )
182- << " Paged KV cache supports one sequence to reuse " << kPagedKVCacheMaxBlockDepth
183- << " prefixes (the fork depth) at most. However, the given sequence has fork depth "
184- << depth;
185181 }
186182
187183 std::vector<int32_t > GetBlockTrace (const std::vector<Block>& global_block_pool) const {
@@ -1490,19 +1486,29 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
14901486 is_chain_ = true ;
14911487 }
14921488
1493- std::vector<std::vector<int32_t >> block_ids_on_depths = GetBlockIdsOnDepth (sequences);
1494- num_depths_ = block_ids_on_depths.size ();
1489+ auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth (sequences);
1490+ num_depths_ =
1491+ std::min (static_cast <int >(block_ids_on_depths.size ()), kPagedKVCacheMaxBlockDepth );
14951492 ICHECK_LE (num_depths_, kPagedKVCacheMaxBlockDepth );
14961493
14971494 std::vector<std::vector<std::pair<int32_t , int32_t >>> chunked_block_ids_arr;
14981495 chunked_block_ids_arr.reserve (num_depths_);
14991496 use_decode_kernel_.clear ();
15001497 for (int d = 0 ; d < num_depths_; ++d) {
1501- auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds (block_ids_on_depths[d]);
1498+ // We force the blocks at maximum depth not to coalesce, so that it can be concatenated with
1499+ // trailing exceeding blocks.
1500+ auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds (
1501+ block_ids_on_depths[d], /* enable_coalesce=*/ d != kPagedKVCacheMaxBlockDepth - 1 );
15021502 chunked_block_ids_arr.push_back (chunked_block_ids);
15031503 use_decode_kernel_.push_back (use_decode_kernel);
15041504 }
15051505
1506+ if (num_depths_ == kPagedKVCacheMaxBlockDepth ) {
1507+ // Since we force the blocks at maximum depth not to coalesce, the output blocks at maximum
1508+ // depth must have the same size as current batch.
1509+ CHECK_EQ (chunked_block_ids_arr[num_depths_ - 1 ].size (), cur_batch_size_);
1510+ }
1511+
15061512 append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0 ];
15071513 if (append_before_attn_) {
15081514 // Right now we use different kernels when depth is 1 or not 1.
@@ -1530,7 +1536,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15301536 k_rope_pos_offset_h.clear ();
15311537 qo_indptr_h.push_back (0 );
15321538 page_indptr_h.push_back (0 );
1533- for (const auto & [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
1539+ for (int i = 0 ; i < static_cast <int >(chunked_block_ids_arr[d].size ()); ++i) {
1540+ const auto & [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
15341541 qo_indptr_h.push_back (qo_indptr_h.back () + chunk_append_length);
15351542 if (block_id == -1 ) {
15361543 page_indptr_h.push_back (page_indptr_h.back ());
@@ -1539,19 +1546,53 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15391546 sink_size_h.push_back (0 );
15401547 k_rope_pos_offset_h.push_back (0 );
15411548 } else {
1542- const Block& block = global_block_pool_[block_id];
1543- page_indptr_h.push_back (page_indptr_h.back () + block.page_ids .size ());
1544- for (int32_t page_id : block.page_ids ) {
1545- page_indices_h.push_back (page_id);
1549+ if (d < kPagedKVCacheMaxBlockDepth - 1 ) {
1550+ // Blocks not at maximum depth
1551+ const Block& block = global_block_pool_[block_id];
1552+ page_indptr_h.push_back (page_indptr_h.back () + block.page_ids .size ());
1553+ for (int32_t page_id : block.page_ids ) {
1554+ page_indices_h.push_back (page_id);
1555+ }
1556+ last_page_len_h.push_back (
1557+ block.seq_length == 0
1558+ ? 0
1559+ : (block.seq_length - block.sink_length + block.sliding_window_offset - 1 ) %
1560+ page_size_ +
1561+ 1 );
1562+ sliding_window_offset_h.push_back (block.sliding_window_offset );
1563+ sink_size_h.push_back (block.sink_length );
1564+ k_rope_pos_offset_h.push_back (block.start_pos );
1565+ } else {
1566+ // Blocks at maximum depth
1567+ const Block& block = global_block_pool_[block_id];
1568+ int32_t num_pages = static_cast <int32_t >(block.page_ids .size ());
1569+ int32_t total_seq_length = static_cast <int32_t >(block.seq_length );
1570+ int32_t last_block_id = block_id;
1571+ for (int32_t page_id : block.page_ids ) {
1572+ page_indices_h.push_back (page_id);
1573+ }
1574+ for (int32_t id : trailing_blocks[i]) {
1575+ // Collect trailing blocks if available
1576+ const Block& block = global_block_pool_[id];
1577+ for (int32_t page_id : block.page_ids ) {
1578+ page_indices_h.push_back (page_id);
1579+ }
1580+ num_pages += block.page_ids .size ();
1581+ total_seq_length += block.seq_length ;
1582+ last_block_id = id;
1583+ }
1584+ page_indptr_h.push_back (page_indptr_h.back () + num_pages);
1585+ const Block& last_block = global_block_pool_[last_block_id];
1586+ last_page_len_h.push_back (total_seq_length == 0
1587+ ? 0
1588+ : (total_seq_length - last_block.sink_length +
1589+ last_block.sliding_window_offset - 1 ) %
1590+ page_size_ +
1591+ 1 );
1592+ sliding_window_offset_h.push_back (last_block.sliding_window_offset );
1593+ sink_size_h.push_back (last_block.sink_length );
1594+ k_rope_pos_offset_h.push_back (block.start_pos );
15461595 }
1547- last_page_len_h.push_back (block.seq_length == 0 ? 0
1548- : (block.seq_length - block.sink_length +
1549- block.sliding_window_offset - 1 ) %
1550- page_size_ +
1551- 1 );
1552- sliding_window_offset_h.push_back (block.sliding_window_offset );
1553- sink_size_h.push_back (block.sink_length );
1554- k_rope_pos_offset_h.push_back (block.start_pos );
15551596 }
15561597 }
15571598 }
@@ -2035,22 +2076,34 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
20352076 /* !
20362077 * \brief For the given list of sequences, check the block trace of
20372078 * each sequence, and return the blocks ids used by the sequences
2038- * on each depth.
2079+ * on each depth. And if the depth is larger than the kPagedKVCacheMaxBlockDepth,
2080+ * the exceeding blocks will concatenate and output separately.
20392081 * More precisely, the inner returned vector contains the block ids
20402082 * used by the sequences on a certain depth (or "-1" if a sequence
20412083 * has fewer depth). The outer returned vector contains the inner
20422084 * vectors from the lowest depth to the highest depth.
20432085 */
2044- std::vector<std::vector<int32_t >> GetBlockIdsOnDepth (
2045- const std::vector<Sequence*>& sequences) const {
2086+ std::pair<std:: vector<std::vector<int32_t >>, std::vector<std::vector< int32_t >>>
2087+ GetBlockIdsOnDepth ( const std::vector<Sequence*>& sequences) const {
20462088 // - Get the trace of each sequence.
20472089 int64_t num_depths = 0 ;
20482090 std::vector<std::vector<int32_t >> seq_block_traces;
2091+ std::vector<std::vector<int32_t >> trailing_block_traces;
20492092 seq_block_traces.reserve (cur_batch_size_);
2093+ trailing_block_traces.reserve (cur_batch_size_);
20502094 for (int i = 0 ; i < cur_batch_size_; ++i) {
20512095 std::vector<int32_t > trace = sequences[i]->GetBlockTrace (global_block_pool_);
2052- num_depths = std::max (num_depths, static_cast <int64_t >(trace.size ()));
2053- seq_block_traces.push_back (std::move (trace));
2096+ if (static_cast <int >(trace.size ()) <= kPagedKVCacheMaxBlockDepth ) {
2097+ seq_block_traces.push_back (std::vector<int32_t >(trace.begin (), trace.end ()));
2098+ trailing_block_traces.push_back ({});
2099+ num_depths = std::max (num_depths, static_cast <int64_t >(trace.size ()));
2100+ } else {
2101+ seq_block_traces.push_back (
2102+ std::vector<int32_t >(trace.begin (), trace.begin () + kPagedKVCacheMaxBlockDepth ));
2103+ trailing_block_traces.push_back (
2104+ std::vector<int32_t >(trace.begin () + kPagedKVCacheMaxBlockDepth , trace.end ()));
2105+ num_depths = std::max (num_depths, static_cast <int64_t >(kPagedKVCacheMaxBlockDepth ));
2106+ }
20542107 }
20552108
20562109 // "Transpose" the traces, yielding the block ids used on each depth.
@@ -2065,7 +2118,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
20652118 }
20662119 block_ids_on_depths.push_back (std::move (block_ids));
20672120 }
2068- return block_ids_on_depths;
2121+ return { block_ids_on_depths, trailing_block_traces} ;
20692122 }
20702123
20712124 /* !
@@ -2081,7 +2134,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
20812134 * input blocks.
20822135 */
20832136 std::pair<std::vector<std::pair<int32_t , int32_t >>, bool > GetChunkedBlockIds (
2084- const std::vector<int32_t >& block_ids) const {
2137+ const std::vector<int32_t >& block_ids, bool enable_coalesce = true ) const {
20852138 std::vector<std::pair<int32_t , int32_t >> uncoalesced_block_ids;
20862139 std::vector<std::pair<int32_t , int32_t >> coalesced_block_ids;
20872140
@@ -2115,8 +2168,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
21152168 double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced;
21162169 // Do not coalesce and use batch decode kernel when coalesce ratio is small.
21172170 bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 1.1 ;
2118-
2119- return {use_decode_kernel ? uncoalesced_block_ids : coalesced_block_ids, use_decode_kernel};
2171+ return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids,
2172+ use_decode_kernel};
21202173 }
21212174
21222175 /* ! \brief Invoke the "begin forward" functions of underlying kernels. */
0 commit comments