
Commit 578dbc8

feat: chunked prefill for MLA (Blackwell) (#4651)
Signed-off-by: Mingyang Jiang <[email protected]>
1 parent: 3fc5754

File tree: 19 files changed, +2320 −40 lines


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 1 addition & 0 deletions
@@ -1713,6 +1713,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     fmhaParams.oSfScalePtr = params.attention_output_sf_scale;
     fmhaParams.stream = stream;
     fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
+    fmhaParams.softmaxStatsPtr = params.softmaxStatsPtr;

     if (mAttentionChunkSize)
     {

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 4 additions & 0 deletions
@@ -124,6 +124,9 @@ class AttentionOp
     int32_t num_encoder_tokens = 0;
     kernels::MlaParams<T>* mla_param = nullptr;

+    // For MLA chunked prefill
+    void* softmaxStatsPtr = nullptr;
+
     std::string enqueueContextParamsToString() const
     {
         // variables from the params coming from the runtime
@@ -173,6 +176,7 @@ class AttentionOp
         ss << "cross_kv_length: " << this->cross_kv_length << std::endl;
         ss << "encoder_input_lengths: " << this->encoder_input_lengths << std::endl;
         ss << "num_encoder_tokens: " << this->num_encoder_tokens << std::endl;
+        ss << "softmaxStatsPtr: " << this->softmaxStatsPtr << std::endl;
         return ss.str();
     }
 };

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
     // Set it to INT_MAX as the kv cache pageOffsets will ensure that there is no out-of-bounds access.
     tllmRunnerParams.mNumPagesInMemPool = INT_MAX;
     tllmRunnerParams.mSfStartTokenIdx = 0;
+    // For mla chunked prefill
+    tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmaxStatsPtr);
     tllmRunnerParams.stream = runnerParams.stream;
     mTllmGenFMHARunner->run(tllmRunnerParams);
 }

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu

Lines changed: 383 additions & 0 deletions
Large diffs are not rendered by default.
New file

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/kvCacheUtils.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+// merged_attn [q_total_len, H=128, D=128] (T)
+// merged_softmax_sum [q_total_len, H, 2] (float), the first part is the max value for each
+// row of P = QK^T, the second part is the softmax sum
+// if merge_op[b] == 0, we just skip this batch, if merge_op[b] == 1, we merge the pre-attn and curr-attn, if
+// merge_op[b]
+// == 2, we only copy curr_attn and curr_softmax_sum to merged_attn and merged_softmax_sum
+template <typename T>
+void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T const* pre_attn,
+    float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size,
+    int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, int const head_size,
+    cudaStream_t stream);
+
+// load single chunk kv from kv_cache for each request
+template <typename T>
+void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
+    int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
+    cudaStream_t stream);
+
+// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
+// zero
+// kv {total_token, 2, H, uncompressed_h=128} 0 for k and 1 for v, k_pe {total_token, h=1, rope_h}
+// input kv and k_pe can be cached tokens or uncached tokens
+template <typename T>
+void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
+    int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
+    int const kv_cache_tokens_per_block, cudaStream_t stream);
+} // namespace kernels
+} // namespace tensorrt_llm

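The merge kernel declared above combines per-chunk attention results using the per-row softmax statistics (row max and softmax sum) that the FMHA path now writes through softmaxStatsPtr. Below is a minimal host-side C++ sketch of the merge math for a single (token, head) row, assuming each partial output has already been normalized by its own softmax sum; the function and variable names are illustrative, not the kernel's actual implementation.

#include <algorithm>
#include <cmath>

// Merge two partial attention results for one (token, head) row.
// Stats are {row_max, softmax_sum} pairs, as described in the header comments.
void mergeRowWithSoftmax(float* merged, float& merged_max, float& merged_sum,
    float const* pre, float pre_max, float pre_sum,    // previously merged chunks
    float const* curr, float curr_max, float curr_sum, // current chunk
    int head_size)
{
    // Rescale both partial softmax sums to a common row max.
    float const new_max = std::max(pre_max, curr_max);
    float const pre_scale = std::exp(pre_max - new_max);
    float const curr_scale = std::exp(curr_max - new_max);
    float const new_sum = pre_sum * pre_scale + curr_sum * curr_scale;

    // Re-weight each (already normalized) partial output by its corrected share of the total sum.
    for (int d = 0; d < head_size; ++d)
    {
        merged[d] = (pre[d] * pre_sum * pre_scale + curr[d] * curr_sum * curr_scale) / new_sum;
    }
    merged_max = new_max;
    merged_sum = new_sum;
}

Per the merge_op comment, op 2 amounts to copying curr_attn and curr_softmax_stats straight into the merged buffers (e.g. for the first processed chunk of a request), op 1 applies a merge like the one above against the running result, and op 0 skips that request entirely.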
cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 10 additions & 6 deletions
@@ -761,7 +761,7 @@ __global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const*
 // q {total_uncached_tokens, h, d_nope + d_rope}
 // latent_cache {total_uncached_tokens, d_k + d_rope}
 template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
-__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T const* latent_cache_ptr,
+__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T* latent_cache_ptr,
     int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
     float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr)
 {
@@ -851,6 +851,10 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T*
     if constexpr (std::is_same_v<TCache, T>)
     {
         reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
+        // copy to latent_cache (for chunked prefill, it will not load kv cache for uncached k_pe)
+        auto const src_k_global_offset
+            = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
+        *reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
     }
     else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
     {
@@ -980,10 +984,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p
 }

 template <typename T, typename TCache>
-void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr,
-    int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
-    int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size,
-    int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
+void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
+    int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
+    float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
+    float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
 {
     dim3 grid(int(tensorrt_llm::common::divUp(max_input_uncached_seq_len, 32)), num_requests, head_num + 1 + 8);
     TLLM_CHECK_WITH_INFO(lora_size == 512, "lora_size should be equal to %d", 512);
@@ -1012,7 +1016,7 @@ INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
         int const num_contexts, int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, \
         int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
     template void invokeMLARopeAppendPagedKVAssignQ<T, TCache>(KVBlockArray & kv_cache, T * q_ptr, \
-        T const* latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \
+        T * latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \
         int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache, \
         size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, \
         cudaStream_t stream);

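The new store in applyMLARopeAppendPagedKVAssignQKernel writes the rotated k_pe values back into latent_cache, so chunked prefill can read the uncached k_pe of the current chunk directly instead of re-reading the paged KV cache. Below is a small sketch of the addressing this implies, following the layout in the kernel comment (latent_cache is {total_uncached_tokens, d_k + d_rope}); the helper is illustrative only and not part of the commit.

#include <cstddef>

// Element offset of the k_pe slice for one token inside latent_cache.
// K_DIM is the compressed-KV ("lora") width, ROPE_DIM the rotary width.
template <int K_DIM, int ROPE_DIM>
size_t latentCacheKpeOffset(size_t global_token_idx, int head_dim_idx)
{
    // Skip this token's first K_DIM elements, then index into the ROPE_DIM-wide k_pe part.
    return global_token_idx * (K_DIM + ROPE_DIM) + K_DIM + static_cast<size_t>(head_dim_idx);
}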
cpp/tensorrt_llm/kernels/mlaKernels.h

Lines changed: 4 additions & 4 deletions
@@ -106,10 +106,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p
     int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);

 template <typename T, typename TCache>
-void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr,
-    int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
-    int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size,
-    int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream);
+void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
+    int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
+    float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
+    float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

 } // namespace kernels
 } // namespace tensorrt_llm

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 13 additions & 5 deletions
@@ -77,7 +77,8 @@ class RunnerBase
         torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
-        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const
+        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
+        torch::optional<torch::Tensor> softmax_stats_tensor) const
         = 0;
 };

@@ -127,7 +128,8 @@ class Runner : public RunnerBase
         torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
-        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const override
+        torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
+        torch::optional<torch::Tensor> softmax_stats_tensor) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
         T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
@@ -279,6 +281,11 @@ class Runner : public RunnerBase
         AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
         enqueue_params.host_block_offsets = host_block_offsets;
         enqueue_params.batch_size = num_seqs;
+        if (softmax_stats_tensor.has_value())
+        {
+            enqueue_params.softmaxStatsPtr = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
+        }
+
         if (op.isMLAEnabled())
         {
             mla_params.cache_seq_lens = sequence_lengths_ptr;
@@ -385,7 +392,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
     std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
     torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
     std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
-    std::optional<int64_t> attention_chunk_size)
+    std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -603,7 +610,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
     }

     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -619,7 +626,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
     }

     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -742,6 +749,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         ", Tensor? mla_context_paged_kv"
         ", Tensor? mla_context_kv_cache_block_offsets"
         ", int? attention_chunk_size"
+        ", Tensor? softmax_stats_tensor"
         ") -> ()");

     m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);
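
On the host side, the new optional softmax_stats_tensor argument is where the context FMHA writes its per-row statistics (the pointer is cast to float2*, i.e. one {max, sum} pair per token and head). Below is a minimal allocation sketch in libtorch, assuming the [q_total_len, num_heads, 2] float layout from the kernel comments; this is not the exact caller code from the commit.

#include <torch/torch.h>

// Allocate the softmax statistics buffer passed as softmax_stats_tensor.
// Layout assumed: [q_total_len, num_heads, 2] float32 on the GPU, read as one float2 per row.
torch::Tensor makeSoftmaxStatsTensor(int64_t q_total_len, int64_t num_heads)
{
    auto options = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    return torch::empty({q_total_len, num_heads, 2}, options);
}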
