diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 9cd1f42e968a..47686efcad18 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 9cd1f42e968a8de7d3af2c7567072e0ad6c8ffed
+Subproject commit 47686efcad186096250ba0f1209ed63ceaaeea58
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 8e126b057f4e..a8c38ca4ed3d 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -48,6 +48,8 @@ namespace relax_vm {
  * prefixes) in paged KV cache.
  */
 constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+/*! \brief The 8MB workspace size for attention auxiliary data. */
+constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
 
 /*!
  * \brief The block structure in paged KV cache with common prefix support.
@@ -279,6 +281,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   NDArray temp_attn_output_device_;
   NDArray temp_attn_scores_device_;
   NDArray merged_attn_scores_device_;
+  std::vector<NDArray> temp_attn_workspace_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
@@ -381,12 +384,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       last_page_len_on_depths_device_.push_back(
           NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
       k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
       last_page_len_on_depths_view_.push_back(NDArray());
       k_rope_pos_offset_view_.push_back(NDArray());
     }
+    // Additional workspace for the "prefill with ragged kv" kernel.
+    temp_attn_workspace_.push_back(
+        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
     cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
     k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
     q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size_}, dtype_aux_, device);
@@ -679,7 +687,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       // Apply rotary embedding to q/k data.
       f_rotary_inplace_(q_data, k_data, cur_append_length_indptr_view_,
                         k_ragged_rope_pos_offset_view_, cur_batch_size_, num_qo_heads_,
-                        num_kv_heads_, head_dim_, /*qkv_layout=*/0, rotary_scale_, rotary_theta_);
+                        num_kv_heads_, head_dim_, rotary_scale_, rotary_theta_);
     }
 
     // Part 3: append k/v data to kv-cache
@@ -945,25 +953,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
     if (append_before_attn_) {
       f_attention_decode_begin_forward_.value()(
-          /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
-          num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
+          /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
+          last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
           /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
+          temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_,
+          num_kv_heads_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
         }
         if (use_decode_kernel_[d]) {
           f_attention_decode_begin_forward_.value()(
-              d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
-              num_kv_heads_, head_dim_, page_size_,
+              d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
+              last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
               /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
         } else {
-          f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
-                                                     last_page_len_on_depths_view_[d]->shape[0],
-                                                     num_qo_heads_, num_kv_heads_);
+          f_attention_prefill_begin_forward_.value()(
+              /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d],
+              last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_);
         }
       }
     }
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index e4c066342b65..dd4b0ea763bb 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -514,9 +514,8 @@ def tir_rotary(
     _1: T.int32,
     _2: T.int32,
     _3: T.int32,
-    _4: T.int32,
+    _4: T.float32,
     _5: T.float32,
-    _6: T.float32,
 ):
     T.func_attr({"tir.is_scheduled": 1})
     total_len = T.int32()
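
Note on the workspace plumbing above: the cache now pre-allocates kPagedKVCacheMaxBlockDepth + 1 fixed 8MB buffers (kAttnWorkspaceByte / 4 float32 elements each, since sizeof(float) == 4) and passes one to each begin-forward call, with index 0 going to the "prefill with ragged kv" kernel and index d + 1 to the depth-d paged kernel. The sketch below is a minimal host-side illustration of that allocation and indexing convention, not the PR itself: plain std::vector<float> stands in for the device-side NDArray::Empty allocations, and main() exists only to make the sketch runnable.

#include <cstddef>
#include <cstdio>
#include <vector>

// Constants copied from the diff above.
constexpr int kPagedKVCacheMaxBlockDepth = 5;
constexpr int kAttnWorkspaceByte = 8 * 1024 * 1024;

int main() {
  // One 8MB float32 workspace per attention depth, plus one extra for the
  // "prefill with ragged kv" kernel, mirroring the constructor changes above.
  std::vector<std::vector<float>> temp_attn_workspace;
  for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
    temp_attn_workspace.emplace_back(kAttnWorkspaceByte / sizeof(float));
  }
  temp_attn_workspace.emplace_back(kAttnWorkspaceByte / sizeof(float));

  // Indexing convention at begin-forward time:
  //   temp_attn_workspace[0]      -> ragged-kv prefill kernel
  //   temp_attn_workspace[d + 1]  -> paged prefill/decode kernel at depth d
  // All buffers are the same size, so what matters is that each kernel
  // gets its own distinct scratch buffer.
  std::printf("%zu workspaces, %zu bytes total\n", temp_attn_workspace.size(),
              temp_attn_workspace.size() * static_cast<std::size_t>(kAttnWorkspaceByte));
  return 0;
}

The apparent intent of handing a pre-allocated buffer to each begin-forward function is to let the planning step reuse the same scratch memory on every forward step rather than allocating per call, which also matches the extra workspace argument the updated FlashInfer wrappers take in this diff.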