Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 56 files
+1 −1 .clang-format
+54 −0 .github/workflows/build-doc.yml
+106 −0 .github/workflows/release_wheel.yml
+1 −0 docs/.gitignore
+20 −0 docs/Makefile
+86 −0 docs/conf.py
+20 −0 docs/index.rst
+35 −0 docs/make.bat
+9 −0 docs/requirements.txt
+127 −19 include/flashinfer/cascade.cuh
+179 −223 include/flashinfer/decode.cuh
+148 −124 include/flashinfer/handler.cuh
+12 −12 include/flashinfer/layout.cuh
+99 −122 include/flashinfer/page.cuh
+125 −117 include/flashinfer/prefill.cuh
+26 −47 include/flashinfer/rope.cuh
+85 −69 include/flashinfer/utils.cuh
+12 −0 python/MANIFEST.in
+49 −28 python/csrc/batch_decode.cu
+74 −50 python/csrc/batch_prefill.cu
+35 −2 python/csrc/cascade.cu
+5 −2 python/csrc/flashinfer_ops.cu
+26 −13 python/csrc/flashinfer_ops.h
+84 −0 python/csrc/page.cu
+1 −1 python/csrc/single_decode.cu
+15 −11 python/flashinfer/__init__.py
+323 −0 python/flashinfer/cascade.py
+327 −0 python/flashinfer/decode.py
+0 −680 python/flashinfer/ops/__init__.py
+0 −12 python/flashinfer/ops/utils.py
+42 −0 python/flashinfer/page.py
+324 −0 python/flashinfer/prefill.py
+56 −0 python/flashinfer/utils.py
+1 −0 python/include
+69 −40 python/setup.py
+29 −12 python/tests/test_batch_decode_kernels.py
+38 −13 python/tests/test_batch_prefill_kernels.py
+204 −3 python/tests/test_shared_prefix_kernels.py
+1 −0 python/version.txt
+8 −0 scripts/ci-flashinfer.env.example
+27 −0 scripts/ci-flashinfer.service
+0 −0 scripts/formatter.sh
+46 −0 scripts/run-ci-build-wheel.sh
+22 −12 src/bench_batch_decode.cu
+53 −35 src/bench_cascade.cu
+7 −7 src/bench_single_decode.cu
+4 −4 src/bench_single_prefill.cu
+11 −8 src/cpu_reference.h
+19 −14 src/test_batch_decode.cu
+16 −12 src/test_batch_prefill.cu
+190 −44 src/test_cascade.cu
+76 −43 src/test_page.cu
+11 −10 src/test_single_decode.cu
+11 −10 src/test_single_prefill.cu
+53 −41 src/tvm_wrapper.cu
+1 −0 version.txt
27 changes: 18 additions & 9 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ namespace relax_vm {
* prefixes) in paged KV cache.
*/
constexpr const int kPagedKVCacheMaxBlockDepth = 5;
/*!
 * \brief The size in bytes (8MB) of each workspace buffer for attention auxiliary data.
 * \note Workspace buffers are allocated as kAttnWorkspaceByte / 4 float32 elements.
 */
constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;

/*!
* \brief The block structure in paged KV cache with common prefix support.
Expand Down Expand Up @@ -279,6 +281,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
NDArray temp_attn_output_device_;
NDArray temp_attn_scores_device_;
NDArray merged_attn_scores_device_;
std::vector<NDArray> temp_attn_workspace_;

//-------------------------------------------
// For efficient memory management, the actual sizes of the arrays
Expand Down Expand Up @@ -381,12 +384,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
last_page_len_on_depths_device_.push_back(
NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
temp_attn_workspace_.push_back(
NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
qo_indptr_on_depths_view_.push_back(NDArray());
page_indptr_on_depths_view_.push_back(NDArray());
page_indices_on_depths_view_.push_back(NDArray());
last_page_len_on_depths_view_.push_back(NDArray());
k_rope_pos_offset_view_.push_back(NDArray());
}
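// One extra workspace in addition to the per-depth ones: the "prefill with ragged kv"
// kernel is given temp_attn_workspace_[0], and depth d is given temp_attn_workspace_[d + 1].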
temp_attn_workspace_.push_back(
NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size_}, dtype_aux_, device);
Expand Down Expand Up @@ -679,7 +687,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
// Apply rotary embedding to q/k data.
f_rotary_inplace_(q_data, k_data, cur_append_length_indptr_view_,
k_ragged_rope_pos_offset_view_, cur_batch_size_, num_qo_heads_,
num_kv_heads_, head_dim_, /*qkv_layout=*/0, rotary_scale_, rotary_theta_);
num_kv_heads_, head_dim_, rotary_scale_, rotary_theta_);
}

// Part 3: append k/v data to kv-cache
Expand Down Expand Up @@ -945,25 +953,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {

if (append_before_attn_) {
f_attention_decode_begin_forward_.value()(
/*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
/*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
} else {
f_attention_prefill_ragged_begin_forward_.value()(
cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_,
num_kv_heads_);
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
}
if (use_decode_kernel_[d]) {
f_attention_decode_begin_forward_.value()(
d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
num_kv_heads_, head_dim_, page_size_,
d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
} else {
f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
last_page_len_on_depths_view_[d]->shape[0],
num_qo_heads_, num_kv_heads_);
f_attention_prefill_begin_forward_.value()(
/*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d],
last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,8 @@ def tir_rotary(
_1: T.int32,
_2: T.int32,
_3: T.int32,
_4: T.int32,
_4: T.float32,
_5: T.float32,
_6: T.float32,
):
T.func_attr({"tir.is_scheduled": 1})
total_len = T.int32()
Expand Down