Merged
Changes from all commits (27 commits)
ec96ea9 draft: add merge attn kernel and related unit test (jmydurant, May 7, 2025)
3360967 fix: fix merge attn bug (jmydurant, May 8, 2025)
b12b4d6 fix: fix unit test bugs (jmydurant, May 8, 2025)
1a0a5a6 fix: fix illegal head idx bug (jmydurant, May 9, 2025)
64eac67 draft: chunked prefill for MLA (jmydurant, May 20, 2025)
b27245f chore: change softmax stats format to [B, S, H, 2] (jmydurant, May 26, 2025)
93b4bb3 feature: remove unnecessary split and copy operation for set mla kernel (jmydurant, May 26, 2025)
ca41a72 feature: remove unnecessary copy operation for load chunked kv cache (jmydurant, May 29, 2025)
cf4a779 WIP: change the work flow (jmydurant, May 30, 2025)
3d95052 draft: change the work flow. It can adapt to any kv length (jmydurant, Jun 3, 2025)
beec65d fix: fix compile err and warning (jmydurant, Jun 3, 2025)
64f6784 fix: fix some bugs, add more test cases, pass cpp UT (jmydurant, Jun 4, 2025)
4b66367 chore: modify pytorch pipeline to support different cached kv len (jmydurant, Jun 5, 2025)
b01730b fix: do not apply mask within loop (jmydurant, Jun 6, 2025)
82a0a1a test: add related pytest code (jmydurant, Jun 6, 2025)
39c5943 fix: fix some pytorch test bugs (jmydurant, Jun 10, 2025)
62dba9a fix: fix bug when attn softmax stats max is not multiplied by bmm scale (jmydurant, Jun 11, 2025)
27aaa9f test: add test case when we only open chunked prefilled (jmydurant, Jun 16, 2025)
7d482af fix: try to resolve rebase conflict. copy k_pe to latent cache (jmydurant, Jun 16, 2025)
73ff8f8 chore: update by code review (jmydurant, Jun 16, 2025)
eb21bc8 chore: update by code review, move temp tensor to meta data (jmydurant, Jun 17, 2025)
d2d369b fix: softmax stats can't be set as params of meta data (jmydurant, Jun 18, 2025)
7f92de4 test: modify test case, remove mtp setting (jmydurant, Jun 18, 2025)
f897353 chore: add license (jmydurant, Jun 18, 2025)
9176976 fix: correct test list (jmydurant, Jun 18, 2025)
8f7e58b fix: fix after rebase, see 13eef642e6f7a909515 for details (jmydurant, Jun 18, 2025)
7da4988 fix: correct test list (jmydurant, Jun 25, 2025)
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -1713,6 +1713,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
fmhaParams.oSfScalePtr = params.attention_output_sf_scale;
fmhaParams.stream = stream;
fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
fmhaParams.softmaxStatsPtr = params.softmaxStatsPtr;

if (mAttentionChunkSize)
{
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -124,6 +124,9 @@ class AttentionOp
int32_t num_encoder_tokens = 0;
kernels::MlaParams<T>* mla_param = nullptr;

// For MLA chunked prefill
void* softmaxStatsPtr = nullptr;

std::string enqueueContextParamsToString() const
{
// variables from the params coming from the runtime
@@ -173,6 +176,7 @@ class AttentionOp
ss << "cross_kv_length: " << this->cross_kv_length << std::endl;
ss << "encoder_input_lengths: " << this->encoder_input_lengths << std::endl;
ss << "num_encoder_tokens: " << this->num_encoder_tokens << std::endl;
ss << "softmaxStatsPtr: " << this->softmaxStatsPtr << std::endl;
return ss.str();
}
};
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -197,6 +197,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
// Set it to INT_MAX as the kv cache pageOffsets will ensure that there is no out-of-bounds access.
tllmRunnerParams.mNumPagesInMemPool = INT_MAX;
tllmRunnerParams.mSfStartTokenIdx = 0;
// For MLA chunked prefill
tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmaxStatsPtr);
tllmRunnerParams.stream = runnerParams.stream;
mTllmGenFMHARunner->run(tllmRunnerParams);
}
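For orientation, the dispatcher treats the opaque softmaxStatsPtr as float2, i.e. one (row max, exp-sum) pair per (token, head), matching the [q_total_len, H, 2] float layout referenced by the commits above. A minimal sketch of that view follows; the names here are illustrative and not part of the PR:

// Hedged sketch: mapping a [q_total_len, H, 2] float stats buffer onto float2.
#include <cstddef>
#include <vector_types.h> // float2 (CUDA toolkit)

struct SoftmaxStatsView
{
    float2* data;     // data[token * numHeads + head] = {rowMax, expSum}
    size_t numHeads;

    float2& at(size_t tokenIdx, size_t headIdx) const
    {
        return data[tokenIdx * numHeads + headIdx];
    }
};

inline size_t softmaxStatsBytes(size_t qTotalLen, size_t numHeads)
{
    return qTotalLen * numHeads * sizeof(float2); // two floats per (token, head)
}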
383 changes: 383 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/kernels/kvCacheUtils.h"

namespace tensorrt_llm
{
namespace kernels
{
// merged_attn [q_total_len, H=128, D=128] (T)
// merged_softmax_sum [q_total_len, H, 2] (float), the first part is the max value for each
// row of P = QK^T, the second part is the softmax sum
// if merge_op[b] == 0, we just skip this batch; if merge_op[b] == 1, we merge the pre-attn and curr-attn;
// if merge_op[b] == 2, we only copy curr_attn and curr_softmax_sum to merged_attn and merged_softmax_sum
template <typename T>
void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T const* pre_attn,
float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size,
int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, int const head_size,
cudaStream_t stream);

// load single chunk kv from kv_cache for each request
template <typename T>
void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
cudaStream_t stream);

// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padded with zeros
// kv {total_token, 2, H, uncompressed_h=128}, where index 0 is k and 1 is v; k_pe {total_token, h=1, rope_h}
// input kv and k_pe can be cached tokens or uncached tokens
template <typename T>
void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
int const kv_cache_tokens_per_block, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm
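To make the first declaration above concrete, here is a hedged, host-side reference of the per-(token, head) merge rule that the (row max, softmax sum) statistics enable; the actual CUDA kernel lives in mlaChunkedPrefill.cu, whose diff is not rendered above, so this sketch only restates the intended math:

// Reference semantics (not the kernel): merge two partial attention outputs
// computed over disjoint KV chunks, given each chunk's softmax statistics.
#include <algorithm>
#include <cmath>

void mergeAttnRowRef(float* mergedOut, float& mergedMax, float& mergedSum,
    float const* preOut, float preMax, float preSum,    // earlier chunk and its stats
    float const* currOut, float currMax, float currSum, // current chunk and its stats
    int headSize)
{
    float const newMax = std::max(preMax, currMax);
    // Each partial output is already normalized by its own softmax sum, so undo that,
    // rescale both sides to the common max, and renormalize by the merged sum.
    float const preWeight = preSum * std::exp(preMax - newMax);
    float const currWeight = currSum * std::exp(currMax - newMax);
    float const newSum = preWeight + currWeight;
    for (int d = 0; d < headSize; ++d)
    {
        mergedOut[d] = (preOut[d] * preWeight + currOut[d] * currWeight) / newSum;
    }
    mergedMax = newMax;
    mergedSum = newSum;
}

Per the header comment, merge_op[b] then selects, per batch entry, between skipping (0), merging as above (1), or copying curr_attn and curr_softmax_sum through unchanged (2).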
16 changes: 10 additions & 6 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -761,7 +761,7 @@ __global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const*
// q {total_uncached_tokens, h, d_nope + d_rope}
// latent_cache {total_uncached_tokens, d_k + d_rope}
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T const* latent_cache_ptr,
__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T* latent_cache_ptr,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr)
{
@@ -851,6 +851,10 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T*
if constexpr (std::is_same_v<TCache, T>)
{
reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
// copy to latent_cache (chunked prefill does not reload the uncached k_pe from the kv cache)
auto const src_k_global_offset
= static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
*reinterpret_cast<VecT*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data;
}
else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
@@ -980,10 +984,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p
}

template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr,
int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size,
int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
{
dim3 grid(int(tensorrt_llm::common::divUp(max_input_uncached_seq_len, 32)), num_requests, head_num + 1 + 8);
TLLM_CHECK_WITH_INFO(lora_size == 512, "lora_size should be equal to %d", 512);
@@ -1012,7 +1016,7 @@ INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
int const num_contexts, int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, \
int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
template void invokeMLARopeAppendPagedKVAssignQ<T, TCache>(KVBlockArray & kv_cache, T * q_ptr, \
T const* latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \
T * latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \
int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache, \
size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, \
cudaStream_t stream);
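The key change to applyMLARopeAppendPagedKVAssignQKernel above is that latent_cache_ptr is now writable: the rotated k_pe of uncached tokens is written back into latent_cache, so chunked prefill does not need to load it from the KV cache again. A hedged sketch of the offset arithmetic, following the latent_cache {total_uncached_tokens, d_k + d_rope} layout noted in the kernel comments (the helper name is illustrative):

// Illustrative only: element offset of one rope value inside latent_cache,
// whose per-token rows are laid out as [compressed KV (kDim) | rope (ropeDim)].
#include <cstddef>

inline size_t latentCacheKPeOffset(size_t globalTokenIdx, size_t kDim, size_t ropeDim, size_t headDimIdx)
{
    // Mirrors the kernel's src_k_global_offset + head_dim_idx: skip whole rows,
    // then the compressed-KV part of the row, then index into the rope dimensions.
    return globalTokenIdx * (kDim + ropeDim) + kDim + headDimIdx;
}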
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -106,10 +106,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p
int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);

template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr,
int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size,
int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream);
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
18 changes: 13 additions & 5 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -77,7 +77,8 @@ class RunnerBase
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
torch::optional<torch::Tensor> mla_context_paged_kv,
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
torch::optional<torch::Tensor> softmax_stats_tensor) const
= 0;
};

@@ -127,7 +128,8 @@ class Runner : public RunnerBase
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
torch::optional<torch::Tensor> mla_context_paged_kv,
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets) const override
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
torch::optional<torch::Tensor> softmax_stats_tensor) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
@@ -279,6 +281,11 @@ class Runner : public RunnerBase
AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = num_seqs;
if (softmax_stats_tensor.has_value())
{
enqueue_params.softmaxStatsPtr = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
}

if (op.isMLAEnabled())
{
mla_params.cache_seq_lens = sequence_lengths_ptr;
@@ -385,7 +392,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
std::optional<int64_t> attention_chunk_size)
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@@ -603,7 +610,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
mla_context_kv_cache_block_offsets);
mla_context_kv_cache_block_offsets, softmax_stats_tensor);
}

if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
Expand All @@ -619,7 +626,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
mla_context_kv_cache_block_offsets);
mla_context_kv_cache_block_offsets, softmax_stats_tensor);
}

TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -742,6 +749,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", Tensor? mla_context_paged_kv"
", Tensor? mla_context_kv_cache_block_offsets"
", int? attention_chunk_size"
", Tensor? softmax_stats_tensor"
") -> ()");

m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);
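On the Torch side, softmax_stats_tensor is an optional argument of the attention op, and the context path forwards its data_ptr as EnqueueContextParams::softmaxStatsPtr. A hedged caller-side sketch in libtorch of how such a buffer could be allocated; the actual allocation site in the PyTorch pipeline is outside this diff, and the shape follows the [q_total_len, H, 2] float layout from the commit history:

// Illustrative allocation of the optional softmax-stats tensor.
#include <torch/torch.h>

torch::Tensor makeSoftmaxStats(int64_t qTotalLen, int64_t numHeads, torch::Device device)
{
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(device);
    // The last dimension packs (row max, exp-sum) per (token, head); the C++ side
    // reads this buffer through a float2* view.
    return torch::zeros({qTotalLen, numHeads, 2}, opts);
}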