 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -120,9 +121,6 @@ struct FusedQKVMaskedAttentionDispatchParams
     bool block_sparse_attention = false;
     BlockSparseParams block_sparse_params;
     int32_t const* mrope_position_deltas;
-    int32_t const* sparse_attn_indices;
-    int32_t const* sparse_attn_offsets;
-    int32_t num_sparse_attn_indices;
 };

 template <typename T, typename KVCacheBuffer>
@@ -203,10 +201,6 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     // Medusa mode will have multiple query tokens.
     xqaParams.multi_query_tokens = mIsSpecDecodingEnabled && mUseSpecDecoding;
     xqaParams.is_spec_dec_tree = mIsSpecDecTree;
-    // Sparse attention parameters for XQA
-    xqaParams.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
-    xqaParams.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;
-    xqaParams.num_sparse_attn_indices = mRuntimeSparseAttentionParams.num_sparse_attn_indices;

     if (mKVCacheQuantMode.hasInt8KvCache())
     {
@@ -294,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
+    // Parameters for sparse attention
+    xqaParams.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
+    xqaParams.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;

     // Cross attention parameters.
     xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
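
The XQA path now takes the sparse attention indices and offsets straight from mRuntimeSparseAttentionParams. As a rough illustration of how such an indices/offsets pair is commonly laid out, here is a minimal CSR-style sketch; the struct name and the exact semantics are assumptions for illustration only, the authoritative definitions live in sparseAttentionKernels.h and the XQA parameter struct.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical mirror of the runtime sparse attention parameters; the real struct
// is defined by the library, this one exists only to make the layout concrete.
struct RuntimeSparseAttentionParamsSketch
{
    int32_t const* sparse_attn_indices; // flattened selected KV indices for all sequences
    int32_t const* sparse_attn_offsets; // batch_size + 1 prefix sums into the indices array
};

int main()
{
    // Two sequences: sequence 0 attends to KV positions {0, 7, 8}, sequence 1 to {3, 4}.
    std::vector<int32_t> indices = {0, 7, 8, 3, 4};
    std::vector<int32_t> offsets = {0, 3, 5}; // offsets[b]..offsets[b+1] spans sequence b

    RuntimeSparseAttentionParamsSketch params{indices.data(), offsets.data()};
    (void) params; // in the real path these pointers are copied into xqaParams above
    return 0;
}
```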
@@ -676,11 +673,6 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS

     params.multi_processor_count = input_params.multi_processor_count;

-    // sparse indices and offsets for attention
-    params.sparse_attn_indices = input_params.sparse_attn_indices;
-    params.sparse_attn_offsets = input_params.sparse_attn_offsets;
-    params.num_sparse_attn_indices = input_params.num_sparse_attn_indices;
-
     // cross attn
     params.memory_length_per_sample = input_params.memory_length_per_sample;

@@ -825,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
 }

 size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
-    int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
+    int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
 {
     if (max_num_tokens == 0)
     {
@@ -908,14 +900,19 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
     size_t const cpMaxPaddedSequenceLength = (batch_beam + mCpSize - 1) / mCpSize * mCpSize;
     size_t const cpWorkspaceSize
         = mCpSize == 1 ? 0 : (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads));
+    // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
+    size_t const sparse_attn_cache_size = (mUseSparseAttention && mEnableXQA)
+        ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence * mNumKVHeads)
+        : 0;

-    int const NUM_BUFFERS = 5;
+    int const NUM_BUFFERS = 6;
     size_t workspaces[NUM_BUFFERS];
     workspaces[0] = partial_out_size;
     workspaces[1] = partial_sum_size;
     workspaces[2] = partial_max_size;
     workspaces[3] = shift_k_cache_size;
     workspaces[4] = cpWorkspaceSize;
+    workspaces[5] = sparse_attn_cache_size;
     generation_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);

     size_t xqa_workspace_size = 0;
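
The new sparse_attn_cache_size term reserves one int per sequence for the sparse sequence lengths, plus batch_beam * 2 * max_blocks_per_sequence * mNumKVHeads ints for the K and V block offsets per KV head. A small worked example of the arithmetic; the values below are illustrative, not library defaults.

```cpp
#include <cstdio>

int main()
{
    int const batch_beam = 8;                 // sequences * beam width (assumed)
    int const max_blocks_per_sequence = 128;  // paged KV blocks per sequence (assumed)
    int const num_kv_heads = 8;               // stand-in for mNumKVHeads (assumed)

    // One int per sequence for the sparse sequence lengths, plus K and V block
    // offsets for every KV head of every block of every sequence.
    size_t const sparse_attn_cache_size
        = sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence * num_kv_heads);

    std::printf("%zu bytes\n", sparse_attn_cache_size); // 4 * (8 + 16384) = 65568 bytes
    return 0;
}
```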
@@ -2275,6 +2272,17 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
             xqaParams.output = mhaOutput;
             xqaParams.qkv = attention_input;
         }
+        if (mUseSparseAttention && std::is_same_v<KVCacheBuffer, KVBlockArray>)
+        {
+            size_t kv_block_offsets_size = batch_beam * 2 * params.max_blocks_per_sequence * mNumKVHeads;
+            size_t seq_lengths_size = batch_beam;
+            int* sparse_kv_block_offsets
+                = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, kv_block_offsets_size));
+            int* sparse_seq_lengths
+                = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, seq_lengths_size));
+            xqaParams.sparse_kv_block_offsets = sparse_kv_block_offsets;
+            xqaParams.sparse_seq_lengths = sparse_seq_lengths;
+        }
         mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
         if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
         {
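
In the generation path the two sparse buffers are carved out of the shared workspace with nextWorkspacePtr, in the same bump-pointer style as the other buffers. Below is a minimal sketch of that carving pattern, assuming a 256-byte alignment granularity; the helper name with the "Sketch" suffix and the alignment constant are assumptions, the real helper in the common utilities may differ.

```cpp
#include <cstdint>
#include <cstdio>

constexpr uintptr_t kAlignment = 256; // assumed alignment granularity

// Return the current position in the workspace and advance the offset,
// rounded up to the alignment boundary, so the next buffer starts aligned.
int8_t* nextWorkspacePtrSketch(int8_t* base, uintptr_t& offset, uintptr_t sizeInBytes)
{
    int8_t* ptr = base + offset;
    offset += (sizeInBytes + kAlignment - 1) / kAlignment * kAlignment;
    return ptr;
}

int main()
{
    alignas(256) static int8_t workspace[1 << 20]; // stand-in for the generation workspace
    uintptr_t offset = 0;

    // Carve the two sparse attention buffers, mirroring the enqueueGeneration change
    // (sizes here are illustrative byte counts, not values from the library).
    auto* sparse_kv_block_offsets
        = reinterpret_cast<int*>(nextWorkspacePtrSketch(workspace, offset, 16384 * sizeof(int)));
    auto* sparse_seq_lengths
        = reinterpret_cast<int*>(nextWorkspacePtrSketch(workspace, offset, 8 * sizeof(int)));

    std::printf("offset after carving: %zu bytes\n", static_cast<size_t>(offset));
    (void) sparse_kv_block_offsets;
    (void) sparse_seq_lengths;
    return 0;
}
```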
@@ -2427,9 +2435,6 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE;
     dispatch_params.block_sparse_params = mBlockSparseParams;
     dispatch_params.mrope_position_deltas = params.mrope_position_deltas;
-    dispatch_params.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
-    dispatch_params.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;
-    dispatch_params.num_sparse_attn_indices = mRuntimeSparseAttentionParams.num_sparse_attn_indices;

     using DataType = typename SATypeConverter<T>::Type;
     if (!isCrossAttention())