
Commit 8a9cdac

Add RocketKV trtllm attention backend
Signed-off-by: yuhangh <[email protected]>

Update sparse attention parameters passing logic
Signed-off-by: yuhangh <[email protected]>

fix rebase breaks
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] add gatherKvPageOffsetsKernel (#32)
* add gatherKvPageOffsetsKernel.
  Signed-off-by: Fanrong Li <[email protected]>
* fix.
  Signed-off-by: Fanrong Li <[email protected]>
* fix.
  Signed-off-by: Fanrong Li <[email protected]>
---------
Signed-off-by: Fanrong Li <[email protected]>

Add sparse kv indices write kernel & fix several bugs
Signed-off-by: yuhangh <[email protected]>

fix for rebase
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] integrate block sparse attention kernels (#33)
* integrate block sparse attention kernels.
  Signed-off-by: Fanrong Li <[email protected]>
* fix.
  Signed-off-by: Fanrong Li <[email protected]>
* Support num_kv_heads in seq_len & fix several workspace size bugs
  Signed-off-by: yuhangh <[email protected]>
* update block sparse attention kernel to support per-head kv_len.
  Signed-off-by: Fanrong Li <[email protected]>
* minor fix
  Signed-off-by: yuhangh <[email protected]>
* update kernel meta info.
* add more block sparse kernels.
* disable rope_fusion for sparse attention.
  Signed-off-by: Fanrong Li <[email protected]>
* fix block sparse attention kernels.
* update block sparse attention kernel.
  Signed-off-by: Fanrong Li <[email protected]>
* fix workspace issue.
  Signed-off-by: Fanrong Li <[email protected]>
* minor fix
  Signed-off-by: yuhangh <[email protected]>
* fix gatherKvPageOffsetsKernel.
  Signed-off-by: Fanrong Li <[email protected]>
* remove cuda stream sync.
  Signed-off-by: Fanrong Li <[email protected]>
---------
Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Co-authored-by: yuhangh <[email protected]>

[None][feat] change the sparse indices format and update the gatherKvPageOffsetsKe… (#34)
* change the sparse indices format and update the gatherKvPageOffsetsKernel.
  Signed-off-by: Fanrong Li <[email protected]>
* update kv write & optimize logic of using tllmgen kernels
  Signed-off-by: yuhangh <[email protected]>
---------
Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Co-authored-by: yuhangh <[email protected]>

add paged kt cache (1st commit).
Signed-off-by: Fanrong Li <[email protected]>

minor fix.
Signed-off-by: Fanrong Li <[email protected]>

fix _single_request_update_kt_cache for vanilla RocketKV.
Signed-off-by: Fanrong Li <[email protected]>

add paged kt cache to rocketkv trtllm.
Signed-off-by: Fanrong Li <[email protected]>

fix _single_request_update_kt_cache for trtllm RocketKV.
Signed-off-by: Fanrong Li <[email protected]>

fix k_snap length.
Signed-off-by: Fanrong Li <[email protected]>

fix memory issue when using paged kt cache.
Signed-off-by: Fanrong Li <[email protected]>

fix rebase breaks
Signed-off-by: yuhangh <[email protected]>

fix rebase bug.
Signed-off-by: Fanrong Li <[email protected]>

fix rebase bug.
Signed-off-by: Fanrong Li <[email protected]>

update block sparse attention kernel.
Signed-off-by: Fanrong Li <[email protected]>

fix params issue
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] Do sparse attention functional clean (#43)
* fix several bugs & adjust some code
  Signed-off-by: yuhangh <[email protected]>
* minor code clean
  Signed-off-by: yuhangh <[email protected]>
* Add simple unittest for rocketkv
  Signed-off-by: yuhangh <[email protected]>
* Adjustment for sparse attention params and example
  Signed-off-by: yuhangh <[email protected]>
* fix bugs introduced by last commit
  Signed-off-by: yuhangh <[email protected]>
* Optimize Xqa_params and num_sparse_kv_tokens
  Signed-off-by: yuhangh <[email protected]>
* Fix gather kernel & minor adjustment
  Signed-off-by: yuhangh <[email protected]>
* Rename sparse_attention_params in xqa_params
  Signed-off-by: yuhangh <[email protected]>
* minor
  Signed-off-by: yuhangh <[email protected]>
---------
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] Update trtllm-gen fmha kernels and remove block sparse cubins (#44)
* rm sparse kernels.
  Signed-off-by: Fanrong Li <[email protected]>
* update new kernel.
  Signed-off-by: Fanrong Li <[email protected]>
* update trtllm-gen fmha.
  Signed-off-by: Fanrong Li <[email protected]>
---------
Signed-off-by: Fanrong Li <[email protected]>

fix rebase conflicts
Signed-off-by: yuhangh <[email protected]>

minor fix
Signed-off-by: yuhangh <[email protected]>

pre-commit fix
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][fix] update trtllm sparse attention interface (#45)
* update trtllm sparse attention interface.
  Signed-off-by: Fanrong Li <[email protected]>
* fix interface.
  Signed-off-by: Fanrong Li <[email protected]>
---------
Signed-off-by: Fanrong Li <[email protected]>

fix rocketkv interface. (#47)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 0879ca6 commit 8a9cdac

File tree

72 files changed: +2672 -584 lines


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 24 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -287,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
+    // Parameters for sparse attention
+    xqaParams.sparse_params = mRuntimeSparseAttentionParams;
+    xqaParams.use_sparse_attention = useTllmGenSparseAttention();

     // Cross attention parameters.
     xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@@ -813,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
 }

 size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
-    int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
+    int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
 {
     if (max_num_tokens == 0)
     {
@@ -909,11 +913,15 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
     size_t xqa_workspace_size = 0;
     if (mEnableXQA)
     {
-        int const XQA_NUM_BUFFERS = 7;
+        int const XQA_NUM_BUFFERS = 8;
         size_t xqa_workspaces[XQA_NUM_BUFFERS];
         size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1);
         size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1);
         size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2;
+        // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
+        size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
+            ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
+            : 0;
         xqa_workspaces[0] = cu_seqlens_size;
         xqa_workspaces[1] = cu_kv_seqlens_size;
         xqa_workspaces[2] = rotary_inv_freq_size;
@@ -922,7 +930,8 @@
         // Scales used for trtllm-gen kernels.
         xqa_workspaces[4] = sizeof(float) * 2;
         xqa_workspaces[5] = sizeof(float);
-        xqa_workspaces[6] = mXqaDispatcher->getWorkspaceSize(
+        xqa_workspaces[6] = sparse_attn_cache_size;
+        xqa_workspaces[7] = mXqaDispatcher->getWorkspaceSize(
             std::min<uint32_t>(mSpecDecodingMaxGenerationLength * max_num_seq, max_num_tokens));
         xqa_workspace_size
             = tc::calculateTotalWorkspaceSize(xqa_workspaces, XQA_NUM_BUFFERS, mXqaDispatcher->getWorkspaceAlignment());
@@ -1647,6 +1656,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.spec_decoding_position_offsets = nullptr;
     preprocessingParams.logn_scaling = params.logn_scaling_ptr;

+    // Sparse KV write
+    preprocessingParams.sparse_kv_indices = mRuntimeSparseAttentionParams.sparse_kv_indices;
+    preprocessingParams.sparse_kv_offsets = mRuntimeSparseAttentionParams.sparse_kv_offsets;
+
     // Scalars
     preprocessingParams.batch_size = params.batch_size;
     preprocessingParams.max_input_seq_len = params.input_seq_length;
@@ -1676,6 +1689,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea

     preprocessingParams.rotary_vision_start = mVisionStart;
     preprocessingParams.rotary_vision_length = mVisionLength;
+    preprocessingParams.is_last_chunk
+        = !mAttentionChunkSize.has_value() || (params.input_seq_length == params.max_past_kv_length);

     {
         std::string const beforeRopeStr = "ctx attention before RoPE at layer " + std::to_string(mLayerIdx);
@@ -1841,6 +1856,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
         sync_check_cuda_error(stream);
     }
+
+    if (!mIsMLAEnabled) // Only for non-MLA attention
+    {
+        invokeKvCachePostprocessing(preprocessingParams, stream);
+        sync_check_cuda_error(stream);
+    }
 }
 else
 {
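The workspace math above is compact, so here is how the extra buffer breaks down. A minimal standalone sketch in plain C++; the helper name sparseAttnCacheSize and the main() driver are hypothetical, and only the formula is taken from the diff:

// Sketch of the sparse-attention workspace sizing used in
// getWorkspaceSizeForGeneration above. Hypothetical helper; the formula is
// copied from the diff, everything else is illustrative.
#include <cstddef>
#include <cstdio>

size_t sparseAttnCacheSize(int batch_beam, int max_blocks_per_sequence, int num_kv_heads)
{
    // Per KV head and per sequence: one gathered sequence length, plus a
    // [2, max_blocks_per_sequence] table of K/V page offsets.
    size_t const seq_lengths = static_cast<size_t>(batch_beam);
    size_t const block_offsets = static_cast<size_t>(batch_beam) * 2 * max_blocks_per_sequence;
    return sizeof(int) * (seq_lengths + block_offsets) * num_kv_heads;
}

int main()
{
    // Example: 8 sequences, at most 64 KV pages each, 4 KV heads:
    // (8 + 8 * 2 * 64) * 4 heads * 4 bytes = 16512 bytes.
    std::printf("%zu bytes\n", sparseAttnCacheSize(8, 64, 4));
    return 0;
}

These are the same two regions that buildXQALaunchParams later carves out as sparse_kv_block_offsets and sparse_seq_lengths.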

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 35 additions & 13 deletions
@@ -26,6 +26,7 @@
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/mlaKernels.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/xqaDispatcher.h"
 #include <cassert>
 #include <set>
@@ -55,7 +56,7 @@ class AttentionOp
         int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept;
     // total_num_seq is the sum of beam_width for multiple requests
     [[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq,
-        int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept;
+        int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept;

     template <typename T>
     class EnqueueParams
@@ -156,14 +157,20 @@ class AttentionOp
         ss << "max_cyclic_attention_window_size: " << this->max_cyclic_attention_window_size << std::endl;
         ss << "can_use_one_more_block: " << (this->can_use_one_more_block ? "true" : "false") << std::endl;
         ss << "sink_token_length: " << this->sink_token_length << std::endl;
-        ss << "context_lengths: "
-           << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
-                  runtime::ITensor::makeShape({batch_size})))
-           << std::endl;
-        ss << "sequence_lengths: "
-           << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
-                  runtime::ITensor::makeShape({batch_size})))
-           << std::endl;
+        if (this->context_lengths && batch_size > 0)
+        {
+            ss << "context_lengths: "
+               << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
+                      runtime::ITensor::makeShape({batch_size})))
+               << std::endl;
+        }
+        if (this->sequence_lengths && batch_size > 0)
+        {
+            ss << "sequence_lengths: "
+               << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
+                      runtime::ITensor::makeShape({batch_size})))
+               << std::endl;
+        }
         ss << "kv_scale_orig_quant: " << this->kv_scale_orig_quant << std::endl;
         ss << "kv_scale_quant_orig: " << this->kv_scale_quant_orig << std::endl;
         ss << "attention_output_orig_quant: " << this->attention_output_orig_quant << std::endl;
@@ -348,6 +355,16 @@ class AttentionOp
         return mIsMLAEnabled;
     }

+    [[nodiscard]] bool useSparseAttention() const
+    {
+        return mUseSparseAttention && mPagedKVCache && mEnableXQA;
+    }
+
+    [[nodiscard]] bool useTllmGenSparseAttention() const
+    {
+        return mUseTllmGenSparseAttention && useSparseAttention();
+    }
+
     [[nodiscard]] int smVersion() const
     {
         return mSM;
@@ -427,6 +444,8 @@ class AttentionOp
     bool mIsMLAEnabled = false;
     bool mIsGenerationMLA = false;
     bool mUseGenFlashMLA = false;
+    bool mUseSparseAttention = false;
+    bool mUseTllmGenSparseAttention = false;
     tensorrt_llm::kernels::MlaMetaParams mMLAParams;
     int mCpSize = 1;
     int mCpRank = 0;
@@ -454,6 +473,8 @@ class AttentionOp
     // Whether to fuse FP4 quant into attention kernel.
     bool mFuseFp4Quant = false;

+    kernels::SparseAttentionParams mRuntimeSparseAttentionParams;
+
     // This is implementation details which we want to save when serializing, but not expose as
     // a plugin field or a constructor parameter
     int32_t mNbMultiBlockSemaphores = 0;
@@ -473,10 +494,11 @@ class AttentionOp
             mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
             mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
             mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
-            mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
-            mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
-            mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
-            mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+            mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
+            mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
+            mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
+            mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+            mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:
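sparseAttentionKernels.h itself is not shown in this commit view. Inferred purely from the call sites in this commit (sparse_kv_indices/sparse_kv_offsets in the KV-cache preprocessing, sparse_attn_indices/sparse_attn_offsets in the gather kernel, plus the data() and toString() calls), SparseAttentionParams plausibly looks like the sketch below; the real definition may well differ:

// Hypothetical sketch of SparseAttentionParams, reverse-engineered from the
// call sites in this commit only. Field names come from usage; layout
// comments are assumptions.
#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>

struct SparseAttentionParams
{
    // Prefill: which KV token positions to keep when writing the KV cache.
    int32_t const* sparse_kv_indices = nullptr; // flattened per-request token indices
    int32_t const* sparse_kv_offsets = nullptr; // prefix sums into the above, size batch_size + 1
    // Generation: which KV cache pages each (head, request) attends to.
    int32_t const* sparse_attn_indices = nullptr; // flattened per-head page indices
    int32_t const* sparse_attn_offsets = nullptr; // prefix sums into the above, size batch_size + 1

    // Used by AttentionOp::data() when hashing/serializing the op configuration.
    auto data() const
    {
        return std::make_tuple(sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
    }

    // Used by XQAParams::toString() for debug dumps.
    std::string toString() const
    {
        std::stringstream ss;
        ss << "sparse_kv_indices: " << sparse_kv_indices << ", sparse_kv_offsets: " << sparse_kv_offsets
           << ", sparse_attn_indices: " << sparse_attn_indices << ", sparse_attn_offsets: " << sparse_attn_offsets;
        return ss.str();
    }
};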

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 13 additions & 0 deletions
@@ -233,6 +233,8 @@ struct XQALaunchParam
     float* bmm2_scale_ptr = nullptr;
     int32_t* semaphores = nullptr;
     void* scratch = nullptr;
+    void* sparse_kv_block_offsets = nullptr;
+    int32_t* sparse_seq_lengths = nullptr;
 };

 // Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion.
@@ -266,6 +268,9 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
     const size_t cu_kv_seqlens_size = sizeof(int) * (batch_beam_size + 1);
     const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
     const size_t tokens_info_size = sizeof(int2) * params.total_num_input_tokens;
+    const size_t kv_block_offsets_size
+        = sizeof(int) * batch_beam_size * 2 * params.max_blocks_per_sequence * params.num_kv_heads;
+    const size_t seq_lengths_size = sizeof(int) * batch_beam_size * params.num_kv_heads;
     launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
     launchParams.cu_kv_seq_lens = reinterpret_cast<int*>(workspace);
@@ -281,6 +286,14 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm1_scale_size);
     launchParams.bmm2_scale_ptr = reinterpret_cast<float*>(workspace);
     workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size);
+    // Used for block sparse attention
+    if (params.use_sparse_attention)
+    {
+        launchParams.sparse_kv_block_offsets = reinterpret_cast<void*>(workspace);
+        workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
+        launchParams.sparse_seq_lengths = reinterpret_cast<int*>(workspace);
+        workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, seq_lengths_size);
+    }
     inputScratch = workspace;
     if (hasOutputScratch)
     {
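The two sparse buffers are carved off the workspace with the same bump-and-align pattern as the other XQA scratch regions. A simplified, self-contained sketch of that pattern follows; the stand-in nextWorkspacePtrWithAlignment below only assumes "advance by size, round up to alignment" (the real helper lives in tensorrt_llm::common and may behave differently in detail):

// Sketch of the workspace-carving pattern used in buildXQALaunchParams.
// Helper and struct here are simplified stand-ins, not the library's API.
#include <cstddef>
#include <cstdint>

void* nextWorkspacePtrWithAlignment(void* ptr, size_t size, size_t alignment = 256)
{
    // Advance past the current buffer, then round up to the next aligned address.
    uintptr_t addr = reinterpret_cast<uintptr_t>(ptr) + size;
    addr = (addr + alignment - 1) / alignment * alignment;
    return reinterpret_cast<void*>(addr);
}

struct SparseLaunchSlices
{
    void* sparse_kv_block_offsets;
    int32_t* sparse_seq_lengths;
};

// Carve the two optional sparse buffers off the workspace tail, as the diff does.
SparseLaunchSlices carveSparseSlices(void*& workspace, size_t kv_block_offsets_size, size_t seq_lengths_size)
{
    SparseLaunchSlices slices{};
    slices.sparse_kv_block_offsets = workspace;
    workspace = nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
    slices.sparse_seq_lengths = reinterpret_cast<int32_t*>(workspace);
    workspace = nextWorkspacePtrWithAlignment(workspace, seq_lengths_size);
    return slices; // workspace now points at the remaining scratch (inputScratch in the diff)
}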

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "tensorrt_llm/common/quantization.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

 namespace tensorrt_llm
 {
@@ -109,6 +110,10 @@ struct XQAParams
     // for cross attention
     int32_t const* encoder_input_lengths = nullptr;

+    // sparse attention parameters
+    SparseAttentionParams sparse_params;
+    bool use_sparse_attention = false;
+
     cudaStream_t stream = 0;

     std::string toString() const
@@ -179,6 +184,8 @@ struct XQAParams
            << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std::endl
            << "fp8_out_scale :" << fp8_out_scale << std::endl
            << "encoder_input_lengths: " << encoder_input_lengths << std::endl
+           << "sparse_params: " << sparse_params.toString() << std::endl
+           << "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std::endl
            << "stream :" << stream;

     return ss.str();
cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu (new file)

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
+#include <cub/cub.cuh>
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+template <int THREADS_PER_BLOCK>
+__global__ void gatherKvPageOffsetsKernel(
+    int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
+    int32_t* output_seq_lengths,     // [num_head_kv, batch_size]
+    int32_t const* kv_page_offsets,  // [batch_size, 2, max_num_pages_per_seq]
+    int32_t const* seq_lengths,      // [batch_size]
+    SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const tokens_per_page,
+    int32_t const max_num_pages_per_seq)
+{
+    // Each CUDA block processes one sequence from the batch for one head.
+    int32_t const head_idx = blockIdx.x;
+    int32_t const batch_idx = blockIdx.y;
+    if (batch_idx >= batch_size)
+    {
+        return;
+    }
+
+    // Shared memory for reduction.
+    __shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
+
+    // Get the range of sparse indices and the sequence length.
+    int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
+    int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
+    int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
+    int32_t const num_sparse_pages = end_offset - start_offset;
+    int32_t const original_seq_len = seq_lengths[batch_idx];
+
+    // Get global sparse index.
+    int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
+
+    // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
+    size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
+    size_t const dst_base_offset = (size_t) head_idx * batch_size * 2 * max_num_pages_per_seq + src_base_offset;
+
+    // Initialize the local max page index and number of valid pages.
+    int32_t local_max_page_index = -1;
+    int32_t local_num_valid_pages = 0;
+
+    // Perform the gather operation.
+    for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
+    {
+        // Get the source idx and offset.
+        int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
+        if (src_idx < 0)
+        {
+            continue;
+        }
+
+        // Update the local max page index.
+        local_max_page_index = max(local_max_page_index, src_idx);
+        local_num_valid_pages++;
+
+        // Get the source and destination offsets.
+        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
+        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
+
+        // Perform the gather operation: read from the sparse location and write to the dense location.
+        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+    }
+
+    // Reduce the local max page indices and number of valid pages.
+    Pair local_pair = {local_max_page_index, local_num_valid_pages};
+    Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
+
+    // Update sequence length for this head and batch.
+    if (threadIdx.x == 0)
+    {
+        int32_t const max_page_index = result.max_val;
+        int32_t const num_valid_pages = result.sum_val;
+        int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
+        size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
+        if (num_valid_pages > 0)
+        {
+            int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
+            int32_t seq_len_remain = original_seq_len % tokens_per_page;
+            if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
+            {
+                seq_len += tokens_per_page - seq_len_remain;
+            }
+            output_seq_lengths[seq_len_offset] = seq_len;
+        }
+        else
+        {
+            output_seq_lengths[seq_len_offset] = 0;
+        }
+    }
+}
+
+// Host-side launcher function
+void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_seq_lengths,
+    int32_t const* kv_page_offsets, int32_t const* seq_lengths, SparseAttentionParams const sparse_params,
+    int32_t const batch_size, int32_t const num_head_kv, int32_t const tokens_per_page,
+    int32_t const max_num_pages_per_seq, cudaStream_t stream)
+{
+    // The grid.
+    dim3 grid(num_head_kv, batch_size, 1);
+    // The block.
+    dim3 block(256, 1, 1);
+    // Shared memory size.
+    size_t smem_size = sizeof(Pair) * 256;
+
+    // Launch the kernel.
+    gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
+        kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
+}
+} // namespace kernels
+} // namespace tensorrt_llm
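The trickiest part of the kernel is the per-(head, request) sequence-length fix-up done by thread 0: dropping a page removes tokens_per_page tokens, except that if the partially filled last page is among the dropped pages, every kept page is full and the "missing" tail tokens must be added back. The host-side sketch below replays that arithmetic on concrete numbers (gatheredSeqLen is a hypothetical name; the body mirrors the kernel):

// Host-side illustration of the sequence-length correction in
// gatherKvPageOffsetsKernel above. For reasoning only; not part of the commit.
#include <cstdio>

int gatheredSeqLen(int original_seq_len, int tokens_per_page, int num_valid_pages, int max_page_index)
{
    if (num_valid_pages == 0)
    {
        return 0;
    }
    int const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
    // Each dropped page removes tokens_per_page tokens.
    int seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
    int const seq_len_remain = original_seq_len % tokens_per_page;
    // If the partially filled last page was dropped, all kept pages are full,
    // so add back the tokens the partial page never had.
    if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
    {
        seq_len += tokens_per_page - seq_len_remain;
    }
    return seq_len;
}

int main()
{
    // 200 tokens with 64-token pages -> 4 pages; the last page holds 8 tokens.
    // Keep 2 pages including the last one: 200 - 2 * 64 = 72 tokens.
    std::printf("%d\n", gatheredSeqLen(200, 64, 2, 3)); // prints 72
    // Keep 2 full pages, last page dropped: 72 + (64 - 8) = 128 tokens.
    std::printf("%d\n", gatheredSeqLen(200, 64, 2, 2)); // prints 128
    return 0;
}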
