
Commit 1ff8d9e

lfr-0531 authored and heyuhhh committed

[None][feat] add gatherKvPageOffsetsKernel (#32)

* add gatherKvPageOffsetsKernel. Signed-off-by: Fanrong Li <[email protected]>
* fix. Signed-off-by: Fanrong Li <[email protected]>
* fix. Signed-off-by: Fanrong Li <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>
1 parent 4896c3c commit 1ff8d9e

File tree

12 files changed: +336 -48 lines changed


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 22 additions & 17 deletions
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -120,9 +121,6 @@ struct FusedQKVMaskedAttentionDispatchParams
     bool block_sparse_attention = false;
     BlockSparseParams block_sparse_params;
     int32_t const* mrope_position_deltas;
-    int32_t const* sparse_attn_indices;
-    int32_t const* sparse_attn_offsets;
-    int32_t num_sparse_attn_indices;
 };

 template <typename T, typename KVCacheBuffer>
@@ -203,10 +201,6 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     // Medusa mode will have multiple query tokens.
     xqaParams.multi_query_tokens = mIsSpecDecodingEnabled && mUseSpecDecoding;
     xqaParams.is_spec_dec_tree = mIsSpecDecTree;
-    // Sparse attention parameters for XQA
-    xqaParams.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
-    xqaParams.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;
-    xqaParams.num_sparse_attn_indices = mRuntimeSparseAttentionParams.num_sparse_attn_indices;

     if (mKVCacheQuantMode.hasInt8KvCache())
     {
@@ -294,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
     xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
+    // Parameters for sparse attention
+    xqaParams.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
+    xqaParams.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;

     // Cross attention parameters.
     xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@@ -676,11 +673,6 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS

     params.multi_processor_count = input_params.multi_processor_count;

-    // sparse indices and offsets for attention
-    params.sparse_attn_indices = input_params.sparse_attn_indices;
-    params.sparse_attn_offsets = input_params.sparse_attn_offsets;
-    params.num_sparse_attn_indices = input_params.num_sparse_attn_indices;
-
     // cross attn
     params.memory_length_per_sample = input_params.memory_length_per_sample;

@@ -825,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
 }

 size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
-    int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
+    int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
 {
     if (max_num_tokens == 0)
     {
@@ -908,14 +900,19 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
     size_t const cpMaxPaddedSequenceLength = (batch_beam + mCpSize - 1) / mCpSize * mCpSize;
     size_t const cpWorkspaceSize
         = mCpSize == 1 ? 0 : (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads));
+    // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
+    size_t const sparse_attn_cache_size = (mUseSparseAttention && mEnableXQA)
+        ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence * mNumKVHeads)
+        : 0;

-    int const NUM_BUFFERS = 5;
+    int const NUM_BUFFERS = 6;
     size_t workspaces[NUM_BUFFERS];
     workspaces[0] = partial_out_size;
     workspaces[1] = partial_sum_size;
     workspaces[2] = partial_max_size;
     workspaces[3] = shift_k_cache_size;
     workspaces[4] = cpWorkspaceSize;
+    workspaces[5] = sparse_attn_cache_size;
     generation_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);

     size_t xqa_workspace_size = 0;
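As a rough sense of scale (not part of the commit), the sketch below evaluates the new sparse_attn_cache_size expression for illustrative values; batch_beam = 4, max_blocks_per_sequence = 64, and num_kv_heads = 8 are assumptions chosen only for this example.

#include <cstddef>
#include <cstdio>

// Minimal sketch of the reservation added above: one int of adjusted sequence length per
// sequence, plus batch_beam * 2 * max_blocks_per_sequence * num_kv_heads ints of gathered
// KV block offsets. All sizes below are hypothetical.
int main()
{
    int const batch_beam = 4;
    int const max_blocks_per_sequence = 64;
    int const num_kv_heads = 8;

    std::size_t const sparse_attn_cache_size = sizeof(int)
        * (batch_beam + static_cast<std::size_t>(batch_beam) * 2 * max_blocks_per_sequence * num_kv_heads);

    std::printf("sparse_attn_cache_size = %zu bytes\n", sparse_attn_cache_size); // 16400 bytes
    return 0;
}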
@@ -2275,6 +2272,17 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
         xqaParams.output = mhaOutput;
         xqaParams.qkv = attention_input;
     }
+    if (mUseSparseAttention && std::is_same_v<KVCacheBuffer, KVBlockArray>)
+    {
+        size_t kv_block_offsets_size = batch_beam * 2 * params.max_blocks_per_sequence * mNumKVHeads;
+        size_t seq_lengths_size = batch_beam;
+        int* sparse_kv_block_offsets
+            = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, kv_block_offsets_size));
+        int* sparse_seq_lengths
+            = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, seq_lengths_size));
+        xqaParams.sparse_kv_block_offsets = sparse_kv_block_offsets;
+        xqaParams.sparse_seq_lengths = sparse_seq_lengths;
+    }
     mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
     if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
     {
@@ -2427,9 +2435,6 @@
     dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE;
     dispatch_params.block_sparse_params = mBlockSparseParams;
     dispatch_params.mrope_position_deltas = params.mrope_position_deltas;
-    dispatch_params.sparse_attn_indices = mRuntimeSparseAttentionParams.sparse_attn_indices;
-    dispatch_params.sparse_attn_offsets = mRuntimeSparseAttentionParams.sparse_attn_offsets;
-    dispatch_params.num_sparse_attn_indices = mRuntimeSparseAttentionParams.num_sparse_attn_indices;

     using DataType = typename SATypeConverter<T>::Type;
     if (!isCrossAttention())

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ class AttentionOp
     int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept;
     // total_num_seq is the sum of beam_width for multiple requests
     [[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq,
-        int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept;
+        int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept;

     template <typename T>
     class EnqueueParams
@@ -488,7 +488,7 @@ class AttentionOp
     mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
     mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
     mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
-    mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
+    mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
     mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
     mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
     mSkipAttn, mFuseFp4Quant, mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h

Lines changed: 0 additions & 5 deletions
@@ -208,11 +208,6 @@ struct Multihead_attention_params_base
     // threadblock counter to identify the complete of partial attention computations
     int* block_counter = nullptr;

-    // sparse indices and offsets for attention calculation
-    int32_t const* sparse_attn_indices = nullptr;
-    int32_t const* sparse_attn_offsets = nullptr;
-    int32_t num_sparse_attn_indices = 0;
-
     int const* memory_length_per_sample = nullptr;
     int32_t const* mrope_position_deltas = nullptr;
 };

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 6 additions & 4 deletions
@@ -108,9 +108,12 @@ struct XQAParams

     // for cross attention
     int32_t const* encoder_input_lengths = nullptr;
-    int32_t num_sparse_attn_indices = 0;
-    int32_t const* sparse_attn_indices = nullptr;
-    int32_t const* sparse_attn_offsets = nullptr;
+
+    // sparse attention parameters
+    int32_t* sparse_attn_indices = nullptr;
+    int32_t* sparse_attn_offsets = nullptr;
+    int* sparse_seq_lengths = nullptr;
+    int* sparse_kv_block_offsets = nullptr;

     cudaStream_t stream = 0;

@@ -182,7 +185,6 @@
            << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
            << "fp8_out_scale :" << fp8_out_scale << std ::endl
            << "encoder_input_lengths: " << encoder_input_lengths << std::endl
-           << "num_sparse_attn_indices :" << num_sparse_attn_indices << std ::endl
            << "sparse_attn_indices :" << sparse_attn_indices << std ::endl
            << "sparse_attn_offsets :" << sparse_attn_offsets << std ::endl
            << "stream :" << stream;
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+__global__ void gatherKvPageOffsetsKernel(
+    int* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
+    int* output_seq_lengths,     // [batch_size]
+    int const* kv_page_offsets,  // [batch_size, 2, max_num_pages_per_seq]
+    int const* seq_lengths,      // [batch_size]
+    SparseAttentionParams const& sparse_params)
+{
+    // Each CUDA block processes one sequence from the batch.
+    int const head_idx = blockIdx.x;
+    int const batch_idx = blockIdx.y;
+    if (batch_idx >= sparse_params.batch_size)
+    {
+        return;
+    }
+
+    // Get the range of sparse indices.
+    int const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
+    int const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
+    int const num_sparse_pages = end_offset - start_offset;
+
+    // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
+    int const max_num_pages_per_seq = sparse_params.max_num_pages_per_seq;
+    size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
+    size_t const dst_base_offset
+        = (size_t) head_idx * sparse_params.batch_size * 2 * max_num_pages_per_seq + src_base_offset;
+
+    // Set the sequence length.
+    if (threadIdx.x == 0)
+    {
+        int const tokens_per_page = sparse_params.tokens_per_page;
+        int const num_pages = (seq_lengths[batch_idx] + tokens_per_page - 1) / tokens_per_page;
+        output_seq_lengths[batch_idx] = seq_lengths[batch_idx] - (num_pages - num_sparse_pages) * tokens_per_page;
+    }
+
+    // Perform the gather operation.
+    for (int i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
+    {
+        // Get the source idx and offset.
+        int const sparse_idx_global = (start_offset + i) * sparse_params.num_head_kv + head_idx;
+        int const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global];
+        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+
+        // Get the destination offset.
+        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
+        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
+
+        // Perform the gather operation: read from the sparse location and write to the dense location.
+        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+    }
+}
+
+// Host-side launcher function
+void invokeGatherKvPageOffsets(int* output_kv_page_offsets, int* output_seq_lengths, int const* kv_page_offsets,
+    int const* seq_lengths, SparseAttentionParams const& sparse_params, cudaStream_t stream)
+{
+    // The grid.
+    dim3 grid(sparse_params.num_head_kv, sparse_params.batch_size, 1);
+    // The block.
+    dim3 block(256, 1, 1);
+
+    // Launch the kernel.
+    gatherKvPageOffsetsKernel<<<grid, block, 0, stream>>>(
+        output_kv_page_offsets, output_seq_lengths, kv_page_offsets, seq_lengths, sparse_params);
+}
+} // namespace kernels
+} // namespace tensorrt_llm
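For orientation only (this is not part of the commit), a minimal host-side sketch of calling the new launcher could look like the following. The function and buffer names are hypothetical; the output shapes follow the comments above ([num_head_kv, batch_size, 2, max_num_pages_per_seq] for the gathered page offsets, [batch_size] for the adjusted sequence lengths). In the actual change these buffers are carved out of the attention workspace in attentionOp.cpp rather than allocated with cudaMalloc.

#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

// Hypothetical usage sketch; assumes sparse_params has already been filled in
// (index/offset pointers, batch_size, num_head_kv, tokens_per_page, max_num_pages_per_seq).
void gatherSparseKvPageOffsetsExample(int const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq], device
    int const* seq_lengths,                                       // [batch_size], device
    tensorrt_llm::kernels::SparseAttentionParams const& sparse_params, cudaStream_t stream)
{
    using tensorrt_llm::kernels::invokeGatherKvPageOffsets;

    size_t const num_offset_elems = static_cast<size_t>(sparse_params.num_head_kv) * sparse_params.batch_size * 2
        * sparse_params.max_num_pages_per_seq;

    int* gathered_kv_page_offsets = nullptr; // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
    int* gathered_seq_lengths = nullptr;     // [batch_size]
    cudaMalloc(&gathered_kv_page_offsets, num_offset_elems * sizeof(int));
    cudaMalloc(&gathered_seq_lengths, sparse_params.batch_size * sizeof(int));

    // One thread block per (KV head, sequence) pair is launched by the helper.
    invokeGatherKvPageOffsets(
        gathered_kv_page_offsets, gathered_seq_lengths, kv_page_offsets, seq_lengths, sparse_params, stream);
    cudaStreamSynchronize(stream);

    cudaFree(gathered_kv_page_offsets);
    cudaFree(gathered_seq_lengths);
}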
Lines changed: 22 additions & 8 deletions
@@ -1,5 +1,8 @@
 #pragma once

+#include <cuda_runtime.h>
+#include <sstream>
+
 namespace tensorrt_llm
 {
 namespace kernels
@@ -13,26 +16,37 @@ struct SparseAttentionParams
     int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1]

     int32_t num_sparse_kv_indices{0};
-    int32_t num_sparse_attn_indices{0};
+
+    // Scalars
+    int32_t batch_size{0};
+    int32_t num_head_kv{0};
+    int32_t tokens_per_page{0};
+    int32_t max_num_pages_per_seq{0};

     std::string toString() const
     {
         std::stringstream ss;
-        ss << "num_sparse_kv_indices: " << this->num_sparse_kv_indices << std::endl;
-        ss << "num_sparse_attn_indices: " << this->num_sparse_attn_indices << std::endl;
-        ss << "sparse_kv_indices: " << this->sparse_kv_indices << std::endl;
-        ss << "sparse_attn_indices: " << this->sparse_attn_indices << std::endl;
-        ss << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl;
-        ss << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl;
+        ss << "num_sparse_kv_indices: " << this->num_sparse_kv_indices << std::endl
+           << "sparse_kv_indices: " << this->sparse_kv_indices << std::endl
+           << "sparse_attn_indices: " << this->sparse_attn_indices << std::endl
+           << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
+           << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
+           << "batch_size: " << this->batch_size << std::endl
+           << "num_head_kv: " << this->num_head_kv << std::endl
+           << "tokens_per_page: " << this->tokens_per_page << std::endl
+           << "max_num_pages_per_seq: " << this->max_num_pages_per_seq << std::endl;
         return ss.str();
     }

     auto data() const
     {
         return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets,
-            num_sparse_kv_indices, num_sparse_attn_indices);
+            num_sparse_kv_indices, batch_size, num_head_kv, tokens_per_page, max_num_pages_per_seq);
     }
 };

+void invokeGatherKvPageOffsets(int* output_kv_page_offsets, int* output_seq_lengths, int const* kv_page_offsets,
+    int const* seq_lengths, SparseAttentionParams const& sparse_params, cudaStream_t stream);
+
 } // namespace kernels
 } // namespace tensorrt_llm
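To make the indexing convention concrete (illustrative numbers, not taken from the commit): sparse_attn_offsets is a prefix sum over sequences, so the kept pages of sequence b occupy positions sparse_attn_offsets[b] to sparse_attn_offsets[b + 1] - 1, and within that range the entry for page i of KV head h sits at (offset + i) * num_head_kv + h in sparse_attn_indices, which is the lookup gatherKvPageOffsetsKernel performs.

#include <cstdio>

// Minimal sketch with hypothetical values; reproduces the flat-index math used by the kernel.
int main()
{
    int const num_head_kv = 2;
    int const sparse_attn_offsets[] = {0, 3, 5}; // two sequences: 3 kept pages, then 2 kept pages (per head)

    int const batch_idx = 1; // second sequence
    int const head_idx = 1;  // second KV head
    int const i = 0;         // first kept page for this (sequence, head)

    int const start_offset = sparse_attn_offsets[batch_idx];
    int const flat_idx = (start_offset + i) * num_head_kv + head_idx;

    // sparse_attn_indices[flat_idx] would hold the source page slot to gather for this head.
    std::printf("flat index into sparse_attn_indices: %d\n", flat_idx); // prints 7
    return 0;
}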

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 21 additions & 2 deletions
@@ -17,6 +17,7 @@
 #include "xqaDispatcher.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
+#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
 #include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
 #include <cstdint>

@@ -404,11 +405,30 @@ void XqaDispatcher::runImpl(
     // Otherwise, always enable the persistent scheduler for better performance.
     tllmRunnerParams.mTileScheduler = params.multi_block_mode ? TileScheduler::Static : TileScheduler::Persistent;

+    // The sequence lengths for K/V.
+    tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
+
     // Q buffer.
     tllmRunnerParams.qPtr = xqa_q_input_ptr;
     // KV buffer
+    bool use_sparse_attention = (params.sparse_attn_indices != nullptr && params.sparse_attn_offsets != nullptr);
     if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
     {
+        // Gather kv page offsets for sparse attention.
+        if (use_sparse_attention)
+        {
+            SparseAttentionParams sparse_params;
+            sparse_params.sparse_attn_indices = params.sparse_attn_indices;
+            sparse_params.sparse_attn_offsets = params.sparse_attn_offsets;
+            sparse_params.batch_size = batch_beam_size;
+            sparse_params.num_head_kv = num_kv_heads;
+            sparse_params.tokens_per_page = kv_cache_buffer.mTokensPerBlock;
+            sparse_params.max_num_pages_per_seq = kv_cache_buffer.mMaxBlocksPerSeq;
+            invokeGatherKvPageOffsets(params.sparse_kv_block_offsets, params.sparse_seq_lengths,
+                launchParams.cu_kv_seq_lens, params.sequence_lengths, sparse_params, params.stream);
+            sync_check_cuda_error(params.stream);
+        }
+
         // Paged KV
         tllmRunnerParams.mQkvLayout = QkvLayout::PagedKv;
         tllmRunnerParams.kvPtr = kv_cache_buffer.mPrimaryPoolPtr;
@@ -419,6 +439,7 @@
     }
     else
     {
+        TLLM_CHECK_WITH_INFO(!use_sparse_attention, "Sparse attention is not supported for KVLinearBuffer.");
         static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
         // Contiguous KV
         tllmRunnerParams.mQkvLayout = QkvLayout::ContiguousKv;
@@ -437,8 +458,6 @@
     tllmRunnerParams.scaleSoftmaxLog2Ptr
         = reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
     tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
-    // The sequence lengths for K/V.
-    tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;

     tllmRunnerParams.oPtr = params.output;
     tllmRunnerParams.oSfPtr = params.output_sf;

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp

Lines changed: 6 additions & 2 deletions
@@ -572,11 +572,15 @@ size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* in
         = isCrossAttention() ? cross_kv_length : (useKVCache() ? inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
     int const max_num_tokens
         = mRemovePadding ? inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] : max_num_seq * max_context_length;
+    auto const& kvCacheBlockOffsetsShape = inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims;
+    int const max_blocks_per_sequence
+        = (useKVCache() && mPagedKVCache) ? inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0;
+
     size_t const context_workspace_size
         = getWorkspaceSizeForContext(type, max_num_seq, max_context_length, cross_kv_length, max_num_tokens);

-    size_t const generation_workspace_size
-        = getWorkspaceSizeForGeneration(type, max_num_seq, max_kv_cache_length, max_num_tokens);
+    size_t const generation_workspace_size = getWorkspaceSizeForGeneration(
+        type, max_num_seq, max_kv_cache_length, max_num_tokens, max_blocks_per_sequence);

     size_t attention_input_workspace_size = 0;
