diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index fc171572071..aba735f8258 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -1713,6 +1713,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea fmhaParams.oSfScalePtr = params.attention_output_sf_scale; fmhaParams.stream = stream; fmhaParams.forceFp32Acc = mFMHAForceFP32Acc; + fmhaParams.softmaxStatsPtr = params.softmaxStatsPtr; if (mAttentionChunkSize) { diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index 2300abf64b4..d19a9cbcc4e 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -124,6 +124,9 @@ class AttentionOp int32_t num_encoder_tokens = 0; kernels::MlaParams* mla_param = nullptr; + // For MLA chunked prefill + void* softmaxStatsPtr = nullptr; + std::string enqueueContextParamsToString() const { // variables from the params coming from the runtime @@ -173,6 +176,7 @@ class AttentionOp ss << "cross_kv_length: " << this->cross_kv_length << std::endl; ss << "encoder_input_lengths: " << this->encoder_input_lengths << std::endl; ss << "num_encoder_tokens: " << this->num_encoder_tokens << std::endl; + ss << "softmaxStatsPtr: " << this->softmaxStatsPtr << std::endl; return ss.str(); } }; diff --git a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp index ca2a1b377c5..7eb6682ec7a 100644 --- a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp +++ b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp @@ -197,6 +197,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams) // Set it to INT_MAX as the kv cache pageOffsets will ensure that there is no out-of-bounds access. tllmRunnerParams.mNumPagesInMemPool = INT_MAX; tllmRunnerParams.mSfStartTokenIdx = 0; + // For mla chunked prefill + tllmRunnerParams.softmaxStatsPtr = reinterpret_cast(runnerParams.softmaxStatsPtr); tllmRunnerParams.stream = runnerParams.stream; mTllmGenFMHARunner->run(tllmRunnerParams); } diff --git a/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu new file mode 100644 index 00000000000..5931403c351 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "mlaChunkedPrefill.cuh"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/mathUtils.h"
+#include
+#include
+
+namespace
+{
+
+template <typename T>
+struct MergeSoftmaxTraits
+{
+    static constexpr int kQKNopeSize = 128;
+    static constexpr int kHeadSize = kQKNopeSize;
+
+    static constexpr int kBytesPerElem = sizeof(T);
+    static constexpr int kBytesPerLoad = 16;
+    static constexpr int kElemPerThread = kBytesPerLoad / sizeof(T);
+    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
+        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
+    static constexpr int kVecPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
+    static constexpr int kTokenPerBlock
+        = std::is_same_v<T, float> ? 4 : 8; // for each block, we fetch 8 tokens for fp16/bf16 and 4 tokens for fp32.
+    static constexpr int kNumThreads = kVecPerHead * kTokenPerBlock;
+
+    union VecReader
+    {
+        cutlass::Array<T, kElemPerThread> data;
+        uint4 reader;
+        static_assert(
+            sizeof(uint4) == sizeof(cutlass::Array<T, kElemPerThread>), "Size mismatch for MergeSoftmaxTraits");
+    };
+};
+
+template <typename T>
+struct loadChunkedKVKernelTraits
+{
+    static constexpr int kLoraSize = 512;
+    static constexpr int kRopeSize = 64;
+    static constexpr int kHeadSize = kLoraSize + kRopeSize;
+    using VecT = uint4;
+    static constexpr int kBytesPerElem = sizeof(T);
+    static constexpr int kBytesPerLoad = 16;
+    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
+    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
+        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
+    static constexpr int kVecPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
+    static constexpr int kThreadPerHead = kVecPerHead; // for each head, we use kThreadPerHead threads to fetch all the
+                                                       // kv cache data; each thread reads the kv cache only once.
+    static constexpr int kTokenPerBlock
+        = std::is_same_v<T, float> ? 4 : 8; // for each block, we fetch 8 tokens for fp16/bf16 and 4 tokens for fp32.
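+    // One thread loads one 16-byte vector of a head, so kThreadPerHead (= kVecPerHead) threads cover a full
+    // (lora + rope) head and a block of kBlockSize threads handles kTokenPerBlock tokens per grid-stride step.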
+    static constexpr int kBlockSize = kThreadPerHead * kTokenPerBlock;
+    static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
+};
+
+template <typename T>
+struct setChunkedKVKernelTraits
+{
+    using VecT = uint4;
+    static constexpr int kQKNopeSize = 128;
+    static constexpr int kVHeadSize = 128;
+    static_assert(kQKNopeSize == kVHeadSize);
+    static constexpr int kRopeSize = 64;
+    static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
+    static constexpr int kBytesPerElem = sizeof(T);
+    static constexpr int kBytesPerLoad = 16;
+    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
+    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
+        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
+    static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
+    static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
+    static constexpr int kCpTokenPerBlock = 16;
+    static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
+};
+
+// merged_attn [q_total_len, H=128, D=128] (T)
+// merged_softmax_sum [q_total_len, H, 2] (float, max/sum)
+template <typename T>
+__global__ void mergeAttnWithSoftmaxKernel(T* merged_attn, float2* merged_softmax_stats, T const* pre_attn,
+    float2 const* pre_softmax_stats, T const* curr_attn, float2 const* curr_softmax_stats, int64_t const* cu_q_seq_len,
+    int64_t const* merge_op, int const num_heads, int const head_size)
+{
+    using KT = MergeSoftmaxTraits<T>;
+    int const batch_idx = static_cast<int>(blockIdx.y);
+    int const head_idx = static_cast<int>(blockIdx.z);
+
+    int64_t merge_op_val = merge_op[batch_idx];
+    if (merge_op_val == 0)
+    {
+        return; // skip this batch
+    }
+
+    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
+    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerThread;
+
+    int const curr_q_len = static_cast<int>(cu_q_seq_len[batch_idx + 1] - cu_q_seq_len[batch_idx]);
+    int const global_q_offset = cu_q_seq_len[batch_idx];
+
+    for (int local_token_idx = (threadIdx.x / KT::kVecPerHead) + blockIdx.x * KT::kTokenPerBlock;
+         local_token_idx < curr_q_len; local_token_idx += gridDim.x * KT::kTokenPerBlock)
+    {
+        // load softmax stat
+        int const global_softmax_stats_offset = (global_q_offset + local_token_idx) * num_heads + head_idx;
+        float2 curr_stats = curr_softmax_stats[global_softmax_stats_offset];
+        // Hack: the current softmax stats max is not multiplied by bmm1_scale.
+        // TODO: delete this line when the trtllm-gen kernel returns the right max value.
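+        // (The correction below rescales curr_stats.x by the bmm1 scale before merging.)
+        // The merge in this loop is the standard two-pass softmax recombination used for chunked prefill: with
+        // per-row stats (max m in .x, sum s in .y) and already-normalized partial outputs O_pre / O_cur, it computes
+        //   m_new = max(m_pre, m_cur)
+        //   s_new = s_pre * exp(m_pre - m_new) + s_cur * exp(m_cur - m_new)
+        //   O_new = (O_pre * s_pre * exp(m_pre - m_new) + O_cur * s_cur * exp(m_cur - m_new)) / s_new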
+ curr_stats.x *= 0.072168784; // 1 / sqrt(128 + 64), head_size is 128 for output, but for bmm1 is 192 + float2 pre_stats = pre_softmax_stats[global_softmax_stats_offset]; + + // load attn + typename KT::VecReader pre_attn_reader{}; + typename KT::VecReader curr_attn_reader{}; + typename KT::VecReader merged_attn_reader{}; + + int const global_attn_offset + = (global_q_offset + local_token_idx) * num_heads * head_size + head_idx * head_size; + + pre_attn_reader.reader + = *reinterpret_cast(pre_attn + global_attn_offset + head_dim_idx); + curr_attn_reader.reader = *reinterpret_cast( + curr_attn + global_attn_offset + head_dim_idx); + + // only copy curr attn and curr softmax sum + if (merge_op_val == 2) + { + *reinterpret_cast(merged_attn + global_attn_offset + head_dim_idx) + = curr_attn_reader.reader; + if (head_dim_idx == 0) + { + merged_softmax_stats[global_softmax_stats_offset] = curr_stats; + } + } + else + { + // merge attn and softmax stats + float2 merged_stats; + merged_stats.x = fmaxf(pre_stats.x, curr_stats.x); + float pre_shift = std::exp(pre_stats.x - merged_stats.x); + float curr_shift = std::exp(curr_stats.x - merged_stats.x); + merged_stats.y = (pre_stats.y * pre_shift + curr_stats.y * curr_shift); + for (int i = 0; i < KT::kElemPerThread; ++i) + { + merged_attn_reader.data[i] + = (static_cast(pre_attn_reader.data[i]) * pre_stats.y * pre_shift + + static_cast(curr_attn_reader.data[i]) * curr_stats.y * curr_shift) + / merged_stats.y; + } + // write merged attn back to global memory + *reinterpret_cast(merged_attn + global_attn_offset + head_dim_idx) + = merged_attn_reader.reader; + // write merged softmax stats back to global memory + if (head_dim_idx == 0) + { + merged_softmax_stats[global_softmax_stats_offset] = merged_stats; + } + } + } +} + +// kv_output {total_chunk_token=b*chunk_size, h=1, d_lora} +// k_pe_output {total_chunk_token, h=1, d_rope} +template +__global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_ptr, + tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_chunked_len, int chunked_size, + int chunked_idx) +{ + using KT = loadChunkedKVKernelTraits; + int const batch_idx = static_cast(blockIdx.y); + [[maybe_unused]] int const head_idx = static_cast(blockIdx.z); // default 0 + + size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead); + size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad; + + int64_t const real_chunked_size = cu_ctx_chunked_len[batch_idx + 1] - cu_ctx_chunked_len[batch_idx]; + int64_t const global_st_offset = cu_ctx_chunked_len[batch_idx]; + if (real_chunked_size <= 0) + { + return; // no kv cache for this batch + } + bool const is_valid_kv = head_dim_vec_idx < KT::kKVThreadPerHead; + for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kTokenPerBlock; + local_token_idx < real_chunked_size; local_token_idx += gridDim.x * KT::kTokenPerBlock) + { + int token_idx_in_kv_cache = (chunked_idx * chunked_size) + local_token_idx; + bool const valid_token = (local_token_idx < chunked_size); + if (valid_token) + { + auto* kvSrc = reinterpret_cast(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache)); + // head_idx === 0 + auto kvBlockIdx + = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast(head_dim_vec_idx)); + auto ld_data = (reinterpret_cast(kvSrc))[kvBlockIdx]; + if (is_valid_kv) + { + // kv_output {total_chunk_token, h=1, d} + int const global_st_idx + = global_st_offset * KT::kLoraSize + local_token_idx * KT::kLoraSize + 
head_dim_idx; + *reinterpret_cast(output_kv_ptr + global_st_idx) = ld_data; + } + else + { + // k_pe_output {total_chunk_token, h=1, d_rope} + int const global_st_idx = global_st_offset * KT::kRopeSize + local_token_idx * KT::kRopeSize + + (head_dim_idx - KT::kLoraSize); + *reinterpret_cast(output_k_pe_ptr + global_st_idx) = ld_data; + } + } + } +} + +// in the most of cases, chunk_size = max_seq_len +// output_kv {B, 2, ceil(max_seq_len / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with +// zero +// kv {token_size = B*chunked_unit_size, 2, H=128, uncompressed_h=128}, k_pe {token_size = B*chunked_unit_size, h=1, +// rope_h} +// cu_seq_lens {batch + 1}, fake cu_seq_len, for chunked prefill is {0, chunk_size, chunk_size * 2 ....} +template +__global__ void setChunkedKVCacheForMLAKernel(T* output_kv, T const* kv, T const* k_pe, int const max_seq_len, + int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens, + int kv_cache_tokens_per_block) +{ + using KT = setChunkedKVKernelTraits; + int const batch_idx = static_cast(blockIdx.y); + int const head_idx = static_cast(blockIdx.z); + int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead); + int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad; + bool const is_valid_kv = head_dim_idx < KT::kQKNopeSize; + + int64_t const global_token_offset = cu_seq_lens[batch_idx]; + int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx]; + int const kv_cache_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block; + int const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size); + int64_t const offset_for_kv_in_mem_pool = kv_cache_block_num * kv_cache_block_size; + int64_t const kv_offset = num_heads * uncompressed_head_size; + size_t const seq_len_loop_end = cache_kv_len; + for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock; + local_token_idx < seq_len_loop_end; local_token_idx += gridDim.x * KT::kCpTokenPerBlock) + { + if (local_token_idx >= cache_kv_len) + { + break; + } + if (is_valid_kv) + { + + int64_t ld_kv_global_offset + = int64_t(global_token_offset + local_token_idx) * 2 * num_heads * uncompressed_head_size + + head_idx * uncompressed_head_size; + int64_t ld_kv_local_offset = head_dim_vec_idx; + auto k_data = (reinterpret_cast(kv + ld_kv_global_offset))[ld_kv_local_offset]; + auto v_data = (reinterpret_cast( + kv + kv_offset + ld_kv_global_offset))[ld_kv_local_offset]; + + int64_t st_k_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool + + local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size + + head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size) + + (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size); + int64_t st_v_global_offset = st_k_global_offset + offset_for_kv_in_mem_pool; + int64_t st_k_local_offset = head_dim_vec_idx; + int64_t st_v_local_offset = head_dim_vec_idx; + (reinterpret_cast(output_kv + st_k_global_offset))[st_k_local_offset] = k_data; + (reinterpret_cast(output_kv + st_v_global_offset))[st_v_local_offset] = v_data; + } + else + { + // rope h = 1 + int64_t ld_rope_global_offset = int64_t(global_token_offset + local_token_idx) * rope_size; + int64_t ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead; + auto rope_data + = (reinterpret_cast(k_pe + ld_rope_global_offset))[ld_rope_local_offset]; + int64_t 
st_rope_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool + + local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size + + head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size) + + (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size); + int64_t st_rope_local_offset = head_dim_vec_idx; + (reinterpret_cast(output_kv + st_rope_global_offset))[st_rope_local_offset] = rope_data; + } + } +} + +} // namespace + +namespace tensorrt_llm +{ +namespace kernels +{ + +// merged_attn [q_total_len, H=128, D=128] (T) +// merged_softmax_sum [q_total_len, H, 2] (float), the first part is the max value for each +// row of P = QK^T, the second part is the softmax sum +// if merge_op[b] == 0, we just skip this batch, if merge_op[b] == 1, we merge the pre-attn and curr-attn, if +// merge_op[b] +// == 2, we only copy curr_attn and curr_softmax_sum to merged_attn and merged_softmax_sum +template +void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T const* pre_attn, + float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, + int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, int const head_size, + cudaStream_t stream) +{ + using KT = MergeSoftmaxTraits; + TLLM_CHECK_WITH_INFO(head_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize); + + dim3 grid(static_cast(tensorrt_llm::common::divUp(max_q_seq_len, KT::kTokenPerBlock)), batch_size, num_heads); + dim3 block(KT::kNumThreads); + + mergeAttnWithSoftmaxKernel<<>>(merged_attn, + reinterpret_cast(merged_softmax_stats), pre_attn, reinterpret_cast(pre_softmax_stats), + curr_attn, reinterpret_cast(curr_softmax_stats), cu_q_seq_len, merge_op, num_heads, head_size); +} + +// load single chunk kv from kv_cache for each request +template +void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts, + int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx, + cudaStream_t stream) +{ + using KT = loadChunkedKVKernelTraits; + TLLM_CHECK_WITH_INFO(lora_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize); + TLLM_CHECK_WITH_INFO(lora_size == KT::kLoraSize, "lora dim should be equal to %d", KT::kLoraSize); + TLLM_CHECK_WITH_INFO(rope_size == KT::kRopeSize, "rope dim should be equal to %d", KT::kRopeSize); + // {chunked_unit_size / token_per_block, batch_size, head_num} + dim3 grid(static_cast(tensorrt_llm::common::divUp(chunked_size, KT::kTokenPerBlock)), num_contexts, 1); + loadChunkedKVCacheForMLAKernel<<>>( + output_kv_ptr, output_k_pe_ptr, kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx); +} + +// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with +// zero +// kv {total_token, 2, H, uncompressed_h=128} 0 for k and 1 for v, k_pe {total_token, h=1, rope_h} +// input kv and k_pe can be cached tokens or uncached tokens +template +void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len, + int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens, + int const kv_cache_tokens_per_block, cudaStream_t stream) +{ + using KT = setChunkedKVKernelTraits; + TLLM_CHECK_WITH_INFO( + uncompressed_head_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize); + 
TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0, + "kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock); + + dim3 grid(tensorrt_llm::common::divUp(max_seq_len, KT::kCpTokenPerBlock), batch_size, num_heads); + setChunkedKVCacheForMLAKernel<<>>(output_kv, kv, k_pe, max_seq_len, num_heads, + uncompressed_head_size, rope_size, cu_seq_lens, kv_cache_tokens_per_block); +} + +#define INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(T) \ + template void invokeMergeAttnWithSoftmax(T * merged_attn, float* merged_softmax_stats, T const* pre_attn, \ + float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, \ + int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, \ + int const head_size, cudaStream_t stream); \ + template void invokeMLALoadChunkedKV(T * output_kv_ptr, T * output_k_pe_ptr, KVBlockArray const& kv_cache, \ + int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, \ + int chunked_idx, cudaStream_t stream); \ + template void invokeMLASetChunkedKV(T * output_kv, T const* kv, T const* k_pe, int const batch_size, \ + int const max_seq_len, int const num_heads, int uncompressed_head_size, int rope_size, \ + int64_t const* cu_seq_lens, int const kv_cache_tokens_per_block, cudaStream_t stream); + +INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(half); +INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(float); +INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(__nv_bfloat16); +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh new file mode 100644 index 00000000000..0b30390b400 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/kvCacheUtils.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +// merged_attn [q_total_len, H=128, D=128] (T) +// merged_softmax_sum [q_total_len, H, 2] (float), the first part is the max value for each +// row of P = QK^T, the second part is the softmax sum +// if merge_op[b] == 0, we just skip this batch, if merge_op[b] == 1, we merge the pre-attn and curr-attn, if +// merge_op[b] +// == 2, we only copy curr_attn and curr_softmax_sum to merged_attn and merged_softmax_sum +template +void invokeMergeAttnWithSoftmax(T* merged_attn, float* merged_softmax_stats, T const* pre_attn, + float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, + int64_t const* cu_q_seq_len, int max_q_seq_len, int64_t const* merge_op, int const num_heads, int const head_size, + cudaStream_t stream); + +// load single chunk kv from kv_cache for each request +template +void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts, + int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx, + cudaStream_t stream); + +// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with +// zero +// kv {total_token, 2, H, uncompressed_h=128} 0 for k and 1 for v, k_pe {total_token, h=1, rope_h} +// input kv and k_pe can be cached tokens or uncached tokens +template +void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len, + int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens, + int const kv_cache_tokens_per_block, cudaStream_t stream); +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.cu b/cpp/tensorrt_llm/kernels/mlaKernels.cu index ffd0c51ec4e..8a9dcc83756 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.cu +++ b/cpp/tensorrt_llm/kernels/mlaKernels.cu @@ -761,7 +761,7 @@ __global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const* // q {total_uncached_tokens, h, d_nope + d_rope} // latent_cache {total_uncached_tokens, d_k + d_rope} template -__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T const* latent_cache_ptr, +__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T* latent_cache_ptr, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr) { @@ -851,6 +851,10 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* if constexpr (std::is_same_v) { reinterpret_cast(kDst)[inBlockIdx] = data; + // copy to latent_cache (for chunked prefill, it will not load kv cache for uncached k_pe) + auto const src_k_global_offset + = static_cast(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM; + *reinterpret_cast(&latent_cache_ptr[src_k_global_offset + head_dim_idx]) = data; } else if constexpr (std::is_same_v) { @@ -980,10 +984,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p } template -void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr, - int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, - int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t 
head_num, int nope_size, int rope_size, - int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream) +void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests, + int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, + float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size, + float const* kv_scale_orig_quant_ptr, cudaStream_t stream) { dim3 grid(int(tensorrt_llm::common::divUp(max_input_uncached_seq_len, 32)), num_requests, head_num + 1 + 8); TLLM_CHECK_WITH_INFO(lora_size == 512, "lora_size should be equal to %d", 512); @@ -1012,7 +1016,7 @@ INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer); int const num_contexts, int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, \ int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \ template void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray & kv_cache, T * q_ptr, \ - T const* latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \ + T * latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \ int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache, \ size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, \ cudaStream_t stream); diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.h b/cpp/tensorrt_llm/kernels/mlaKernels.h index 812df1b3742..3d5aa4f148d 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.h +++ b/cpp/tensorrt_llm/kernels/mlaKernels.h @@ -106,10 +106,10 @@ void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_p int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream); template -void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr, - int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, - int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, - int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream); +void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests, + int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, + float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size, + float const* kv_scale_orig_quant_ptr, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 4c64fec1a65..e17f43c0dc1 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -77,7 +77,8 @@ class RunnerBase torch::optional q_pe, torch::optional block_ids_per_seq, torch::optional mrope_rotary_cos_sin, torch::optional mrope_position_deltas, torch::optional mla_context_paged_kv, - torch::optional mla_context_kv_cache_block_offsets) const + torch::optional mla_context_kv_cache_block_offsets, + torch::optional softmax_stats_tensor) const = 0; }; @@ -127,7 +128,8 @@ class Runner : public RunnerBase torch::optional q_pe, torch::optional block_ids_per_seq, torch::optional mrope_rotary_cos_sin, torch::optional mrope_position_deltas, torch::optional mla_context_paged_kv, - torch::optional mla_context_kv_cache_block_offsets) const 
override + torch::optional mla_context_kv_cache_block_offsets, + torch::optional softmax_stats_tensor) const override { auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); T* attention_input = static_cast(qkv.slice(0, token_offset).data_ptr()); @@ -279,6 +281,11 @@ class Runner : public RunnerBase AttentionOp::EnqueueContextParams enqueue_params{common_enqueue_params}; enqueue_params.host_block_offsets = host_block_offsets; enqueue_params.batch_size = num_seqs; + if (softmax_stats_tensor.has_value()) + { + enqueue_params.softmaxStatsPtr = static_cast(softmax_stats_tensor.value().data_ptr()); + } + if (op.isMLAEnabled()) { mla_params.cache_seq_lens = sequence_lengths_ptr; @@ -385,7 +392,7 @@ void attention_inplace(torch::Tensor q, torch::optional k, torch: std::optional qk_rope_head_dim, std::optional v_head_dim, torch::optional mrope_rotary_cos_sin, torch::optional mrope_position_deltas, std::optional mla_context_paged_kv, std::optional mla_context_kv_cache_block_offsets, - std::optional attention_chunk_size) + std::optional attention_chunk_size, std::optional softmax_stats_tensor) { TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx); // Use these tensors to infer if the attention is using KV cache @@ -603,7 +610,7 @@ void attention_inplace(torch::Tensor q, torch::optional k, torch: host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv, - mla_context_kv_cache_block_offsets); + mla_context_kv_cache_block_offsets, softmax_stats_tensor); } if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly)) @@ -619,7 +626,7 @@ void attention_inplace(torch::Tensor q, torch::optional k, torch: host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv, - mla_context_kv_cache_block_offsets); + mla_context_kv_cache_block_offsets, softmax_stats_tensor); } TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx); @@ -742,6 +749,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) ", Tensor? mla_context_paged_kv" ", Tensor? mla_context_kv_cache_block_offsets" ", int? attention_chunk_size" + ", Tensor? 
softmax_stats_tensor" ") -> ()"); m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output); diff --git a/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp b/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp index 30acb0cefb9..60196e16388 100644 --- a/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp +++ b/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp @@ -18,6 +18,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/kernels/kvCacheUtils.h" +#include "tensorrt_llm/kernels/mlaChunkedPrefill.cuh" #include "tensorrt_llm/kernels/mlaKernels.h" #include "tensorrt_llm/thop/thUtils.h" #include @@ -47,6 +48,19 @@ void loadPagedKVCacheForMLAHelper(torch::Tensor& compressed_kv, torch::Tensor& k cu_ctx_cached_kv_lens_ptr, max_input_seq_len, lora_size, rope_size, kv_scale_quant_orig_ptr, stream); } +template +void loadChunkedKVCacheForMLAHelper(torch::Tensor& output_kv, torch::Tensor& output_k_pe, KVBlockArray& kv_cache, + int const num_contexts, torch::Tensor const& cu_ctx_chunked_len, int lora_size, int rope_size, + int const chunked_size, int const chunked_idx) +{ + auto stream = at::cuda::getCurrentCUDAStream(output_kv.get_device()); + + T* output_kv_ptr = static_cast(output_kv.data_ptr()); + T* output_k_pe_ptr = static_cast(output_k_pe.data_ptr()); + tensorrt_llm::kernels::invokeMLALoadChunkedKV(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts, + cu_ctx_chunked_len.data_ptr(), lora_size, rope_size, chunked_size, chunked_idx, stream); +} + template void setPagedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v, torch::Tensor const& k_pe, int const num_requests, torch::Tensor const& cu_seq_lens, int const max_input_seq_len, @@ -65,16 +79,30 @@ void setPagedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& k, max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride, stream); } +template +void setChunkedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe, + int const num_requests, torch::Tensor const& cu_seq_lens, int num_heads, int kv_dim, int rope_dim, + int kv_cache_tokens_per_block, int max_seq_len) +{ + auto stream = at::cuda::getCurrentCUDAStream(output.get_device()); + T* output_ptr = static_cast(output.data_ptr()); + T* kv_ptr = static_cast(kv.data_ptr()); + T* k_pe_ptr = static_cast(k_pe.data_ptr()); + auto* cu_seq_lens_ptr = cu_seq_lens.data_ptr(); + + tensorrt_llm::kernels::invokeMLASetChunkedKV(output_ptr, kv_ptr, k_pe_ptr, num_requests, max_seq_len, num_heads, + kv_dim, rope_dim, cu_seq_lens_ptr, kv_cache_tokens_per_block, stream); +} + template -void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tensor& q, - torch::Tensor const& latent_cache, int const num_requests, torch::Tensor const& cu_ctx_cached_kv_lens, - torch::Tensor const& cu_seq_lens, int const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, - int const head_num, int const nope_size, int const rope_size, int const lora_size, - float const* kv_scale_orig_quant_ptr) +void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tensor& q, torch::Tensor& latent_cache, + int const num_requests, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens, + int const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int const head_num, int const nope_size, + int const rope_size, int const lora_size, float const* kv_scale_orig_quant_ptr) { auto stream = 
at::cuda::getCurrentCUDAStream(q.get_device()); auto* q_ptr = static_cast(q.data_ptr()); - auto const* latent_cache_ptr = static_cast(latent_cache.data_ptr()); + auto* latent_cache_ptr = static_cast(latent_cache.data_ptr()); auto const* cu_ctx_cached_kv_lens_ptr = cu_ctx_cached_kv_lens.data_ptr(); auto const* cu_seq_lens_ptr = cu_seq_lens.data_ptr(); auto const* cos_sin_cache_ptr = static_cast(cos_sin_cache.data_ptr()); @@ -83,6 +111,25 @@ void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tens rope_size, lora_size, kv_scale_orig_quant_ptr, stream); } +template +void mergeChunkedAttentionForMLAHelper(torch::Tensor& merged_attn, torch::Tensor const& temp_attn, + torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, int64_t const num_requests, + torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len, torch::Tensor const& merge_op, + int64_t const num_heads, int64_t const head_size) +{ + auto stream = at::cuda::getCurrentCUDAStream(merged_attn.get_device()); + T* merged_attn_ptr = static_cast(merged_attn.data_ptr()); + T* temp_attn_ptr = static_cast(temp_attn.data_ptr()); + float* merged_softmax_stats_ptr = static_cast(merged_softmax_stats.data_ptr()); + float* temp_softmax_stats_ptr = static_cast(temp_softmax_stats.data_ptr()); + int64_t* const cu_q_seq_lens_ptr = cu_q_seq_lens.data_ptr(); + int64_t* const merge_op_ptr = merge_op.data_ptr(); + + tensorrt_llm::kernels::invokeMergeAttnWithSoftmax(merged_attn_ptr, merged_softmax_stats_ptr, merged_attn_ptr, + merged_softmax_stats_ptr, temp_attn_ptr, temp_softmax_stats_ptr, num_requests, cu_q_seq_lens_ptr, max_q_seq_len, + merge_op_ptr, num_heads, head_size, stream); +} + /** * Creates a KVBlockArray object for managing KV cache * @@ -233,6 +280,70 @@ std::vector loadPagedKVCacheForMLA(torch::ScalarType out_dtype, i return outputs; } +std::vector loadChunkedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts, + int64_t const num_ctx_cached_tokens, torch::Tensor& cu_ctx_chunked_kv_lens, + torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers, + torch::Tensor const& host_kv_cache_pool_mapping, torch::optional kv_scale_orig_quant, + torch::optional kv_scale_quant_orig, int64_t const layer_idx, int64_t const lora_size, + int64_t const rope_size, int64_t const tokens_per_block, int64_t const chunked_size, int64_t const chunked_index, + int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, + int64_t const quant_mode) +{ + TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16, + "out_dtype only support float16, float32, bfloat16"); + TLLM_CHECK(num_contexts > 0); + CHECK_INPUT(cu_ctx_chunked_kv_lens, torch::kInt64); + TORCH_CHECK(cu_ctx_chunked_kv_lens.dim() == 1); + TORCH_CHECK(cu_ctx_chunked_kv_lens.size(0) >= num_contexts + 1); + int head_size = lora_size + rope_size; + auto kv_cache_quant_mode = tc::QuantMode(static_cast(quant_mode)); + int max_blocks_per_sequence = kv_cache_block_offsets.size(-1); + KVBlockArray kv_cache_buffer + = createKVBlockArray(num_contexts, max_blocks_per_sequence, tokens_per_block, head_size, + 1, // num_kv_heads is always 1 for MLA + attention_window_size, sink_token_length, beam_width, kv_cache_quant_mode, out_dtype, + host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, kv_cache_block_offsets, layer_idx); + + float const* kv_scale_orig_quant_ptr = nullptr; + float const* kv_scale_quant_orig_ptr = nullptr; + if 
(kv_cache_quant_mode.hasKvCacheQuant()) + { + TORCH_CHECK(kv_scale_orig_quant.has_value()); + TORCH_CHECK(kv_scale_quant_orig.has_value()); + kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr(); + kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr(); + TLLM_CHECK(kv_scale_orig_quant_ptr != nullptr); + TLLM_CHECK(kv_scale_quant_orig_ptr != nullptr); + } + + std::vector outputs; + + // compressed_kv {num_ctx_cached_tokens, lora_size} + outputs.push_back(torch::empty( + {num_ctx_cached_tokens, lora_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false))); + // k_pe {num_ctx_cached_tokens, rope_size} + outputs.push_back(torch::empty( + {num_ctx_cached_tokens, rope_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false))); + + if (out_dtype == torch::kFloat16) + { + loadChunkedKVCacheForMLAHelper(outputs[0], outputs[1], kv_cache_buffer, num_contexts, + cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index); + } + else if (out_dtype == torch::kFloat32) + { + loadChunkedKVCacheForMLAHelper(outputs[0], outputs[1], kv_cache_buffer, num_contexts, + cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index); + } + else if (out_dtype == torch::kBFloat16) + { + loadChunkedKVCacheForMLAHelper<__nv_bfloat16>(outputs[0], outputs[1], kv_cache_buffer, num_contexts, + cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index); + } + + return outputs; +} + torch::Tensor setPagedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v, torch::Tensor const& k_pe, int64_t const num_requests, torch::Tensor const& cu_seq_lens, int64_t const max_input_seq_len, int64_t const num_heads, int64_t const kv_dim, int64_t const rope_dim, @@ -296,7 +407,49 @@ torch::Tensor setPagedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& return faked_kv_cache_block_offsets; } -void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor const& latent_cache, int64_t const num_contexts, +torch::Tensor setChunkedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe, + int64_t const num_requests, torch::Tensor const& cu_seq_lens, int64_t const num_heads, int64_t const kv_dim, + int64_t const rope_dim, int64_t const kv_cache_tokens_per_block, int64_t const max_seq_len) +{ + TORCH_CHECK(output.numel() > 0); + TORCH_CHECK(output.scalar_type() == torch::kFloat16 || output.scalar_type() == torch::kFloat32 + || output.scalar_type() == torch::kBFloat16); + CHECK_TH_CUDA(output); + CHECK_CONTIGUOUS(output); + CHECK_INPUT(kv, output.scalar_type()); + CHECK_INPUT(k_pe, output.scalar_type()); + CHECK_INPUT(cu_seq_lens, torch::kInt64); + TORCH_CHECK(cu_seq_lens.dim() == 1); + TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1); + + if (output.scalar_type() == torch::kFloat16) + { + setChunkedKVCacheForMLAHelper(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim, + kv_cache_tokens_per_block, max_seq_len); + } + else if (output.scalar_type() == torch::kFloat32) + { + setChunkedKVCacheForMLAHelper(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim, + kv_cache_tokens_per_block, max_seq_len); + } + else if (output.scalar_type() == torch::kBFloat16) + { + setChunkedKVCacheForMLAHelper<__nv_bfloat16>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, + rope_dim, kv_cache_tokens_per_block, max_seq_len); + } + + int64_t max_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block; + + // 
TODO: actually this offset is always the same for all requests and all layers. + torch::Tensor faked_kv_cache_block_offsets = torch::arange( + 0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device())); + + faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num}); + + return faked_kv_cache_block_offsets; +} + +void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache, int64_t const num_contexts, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens, int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num, int64_t const nope_size, int64_t const rope_size, int64_t const lora_size, @@ -391,6 +544,35 @@ void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor const& latent_c } } +void mergeChunkedAttentionForMLA(torch::Tensor& merged_attn, torch::Tensor const& temp_attn, + torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, int64_t const num_requests, + torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len, torch::Tensor const& merge_op, + int64_t const num_heads, int64_t const head_size) +{ + TORCH_CHECK(merged_attn.numel() > 0); + TORCH_CHECK(temp_attn.numel() > 0); + TORCH_CHECK(merged_attn.scalar_type() == temp_attn.scalar_type()); + TORCH_CHECK(merged_attn.scalar_type() == torch::kFloat16 || merged_attn.scalar_type() == torch::kFloat32 + || merged_attn.scalar_type() == torch::kBFloat16); + TORCH_CHECK(temp_softmax_stats.scalar_type() == merged_softmax_stats.scalar_type()); + TORCH_CHECK(merged_softmax_stats.scalar_type() == torch::kFloat32); + + if (merged_attn.scalar_type() == torch::kFloat16) + { + mergeChunkedAttentionForMLAHelper(merged_attn, temp_attn, merged_softmax_stats, temp_softmax_stats, + num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size); + } + else if (merged_attn.scalar_type() == torch::kFloat32) + { + mergeChunkedAttentionForMLAHelper(merged_attn, temp_attn, merged_softmax_stats, temp_softmax_stats, + num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size); + } + else if (merged_attn.scalar_type() == torch::kBFloat16) + { + mergeChunkedAttentionForMLAHelper<__nv_bfloat16>(merged_attn, temp_attn, merged_softmax_stats, + temp_softmax_stats, num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size); + } +} } // namespace torch_ext TORCH_LIBRARY_FRAGMENT(trtllm, m) @@ -424,6 +606,37 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("load_paged_kv_cache_for_mla", &torch_ext::loadPagedKVCacheForMLA); } +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "load_chunked_kv_cache_for_mla(" + "ScalarType out_dtype" + ", int num_contexts" + ", int num_ctx_cached_tokens" + ", Tensor cu_ctx_chunked_kv_lens" + ", Tensor kv_cache_block_offsets" + ", Tensor host_kv_cache_pool_pointers" + ", Tensor host_kv_cache_pool_mapping" + ", Tensor? kv_scale_orig_quant" + ", Tensor? 
kv_scale_quant_orig" + ", int layer_idx" + ", int lora_size" + ", int rope_size" + ", int tokens_per_block" + ", int chunked_size" + ", int chunked_index" + ", int attention_window_size" + ", int sink_token_length" + ", int beam_width" + ", int quant_mode" + ") -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("load_chunked_kv_cache_for_mla", &torch_ext::loadChunkedKVCacheForMLA); +} + TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( @@ -447,6 +660,28 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("set_paged_kv_cache_for_mla", &torch_ext::setPagedKVCacheForMLA); } +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "set_chunked_kv_cache_for_mla(" + "Tensor output" + ", Tensor kv" + ", Tensor k_pe" + ", int num_requests" + ", Tensor cu_seq_lens" + ", int num_heads" + ", int kv_dim" + ", int rope_dim" + ", int kv_cache_tokens_per_block" + ", int max_seq_len" + ") -> Tensor"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("set_chunked_kv_cache_for_mla", &torch_ext::setChunkedKVCacheForMLA); +} + TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( @@ -481,3 +716,25 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("mla_rope_append_paged_kv_assign_q", &torch_ext::MLARopeAppendPagedKVAssignQ); } + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "merge_chunked_attention_for_mla(" + "Tensor merged_attn" + ", Tensor temp_attn" + ", Tensor merged_softmax_stats" + ", Tensor temp_softmax_stats" + ", int num_requests" + ", Tensor cu_q_seq_lens" + ", int max_q_seq_len" + ", Tensor merge_op" + ", int num_heads" + ", int head_size" + ") -> ()"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("merge_chunked_attention_for_mla", &torch_ext::mergeChunkedAttentionForMLA); +} diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt index 2919286ab18..cc04f9ce96b 100644 --- a/cpp/tests/unit_tests/kernels/CMakeLists.txt +++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt @@ -30,6 +30,8 @@ add_gtest(mlaPreprocessTest mlaPreprocessTest.cu) add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp) +add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu) + if(NOT ENABLE_MULTI_DEVICE EQUAL 0) add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu) add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu) diff --git a/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu b/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu new file mode 100644 index 00000000000..aba26283f7f --- /dev/null +++ b/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu @@ -0,0 +1,1079 @@ +#include +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/quantization.h" +#include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/kernels/kvCacheUtils.h" +#include "tensorrt_llm/runtime/bufferManager.h" + +#include "tensorrt_llm/kernels/mlaChunkedPrefill.cuh" +#include "tensorrt_llm/runtime/cudaStream.h" +#include +#include +#include +#include + +// #define TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG + +namespace +{ +// kv_output {total_tokens, h=1, lora_size} +// k_pe_output {total_tokens, h=1, rope_size} +template +void loadChunkedKVKernelRef(T* kv_output, T* k_pe_output, tensorrt_llm::kernels::KVBlockArray const& kv_cache, + int num_contexts, int64_t const* cu_ctx_chunked_len, int const lora_size, int const rope_size, int const chunk_size, + int const chunk_idx) +{ + int const head_size = lora_size + rope_size; + for (int b = 0; b < num_contexts; b++) + { + int const chunked_len = cu_ctx_chunked_len[b + 1] 
- cu_ctx_chunked_len[b]; + for (int s = 0; s < chunked_len; s++) + { + int const local_token_idx = chunk_idx * chunk_size + s; + int const ld_token_offset = (cu_ctx_chunked_len[b] + s); + + auto const* kv_src = reinterpret_cast(kv_cache.getKBlockPtr(b, local_token_idx)); + for (int d = 0; d < head_size; d++) + { + auto kv_block_idx = kv_cache.getKVLocalIdx(local_token_idx, 0, head_size, d); + auto src_data = kv_src[kv_block_idx]; + + if (d < lora_size) + { + kv_output[ld_token_offset * lora_size + d] = src_data; + } + else + { + k_pe_output[ld_token_offset * rope_size + (d - lora_size)] = src_data; + } + } + } + } +} + +// kv {total_tokens, 2, h, nope_size} +// k_pe {total_tokens, h=1, rope_size} +// output {b, 2, ceil(max_seq / cache_tokens_per_block), h, cache_tokens_per_block, (nope_size + rope_size)} +// max_seq <= chunk_size +template +void setChunkedKVCacheForMLAKernelRef(T* output, T* kv_ptr, T* k_pe_ptr, int num_contexts, int64_t const* cu_seq_len, + int const max_input_seq_len, int num_heads, int nope_size, int rope_size, int cache_tokens_per_block) +{ + int head_size = nope_size + rope_size; + int const kv_cache_size_per_block = num_heads * cache_tokens_per_block * head_size; + int const kv_cache_block_num_per_seq = (max_input_seq_len + cache_tokens_per_block - 1) / cache_tokens_per_block; + for (int b = 0; b < num_contexts; b++) + { + int const global_token_offset = cu_seq_len[b]; + int const current_seq_len = cu_seq_len[b + 1] - cu_seq_len[b]; + for (int s = 0; s < current_seq_len; s++) + { + int const global_token_idx = global_token_offset + s; + int const kv_cache_block_offset_for_k + = (b * 2 * kv_cache_block_num_per_seq + s / cache_tokens_per_block) * kv_cache_size_per_block; + int const kv_cache_block_offset_for_v + = kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block); + for (int h = 0; h < num_heads; h++) + { + int const ld_k_head_offset = (global_token_idx * 2 * num_heads * nope_size) + h * nope_size; + int const ld_v_head_offset = ld_k_head_offset + num_heads * nope_size; + int const ld_k_pe_head_offset = global_token_idx * rope_size; + // copy kv + for (int d = 0; d < nope_size; d++) + { + int const ld_k_idx = ld_k_head_offset + d; + int const ld_v_idx = ld_v_head_offset + d; + int const st_k_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size + + (s % cache_tokens_per_block) * head_size + d; + int const st_v_idx = kv_cache_block_offset_for_v + h * cache_tokens_per_block * head_size + + (s % cache_tokens_per_block) * head_size + d; + output[st_k_idx] = kv_ptr[ld_k_idx]; + output[st_v_idx] = kv_ptr[ld_v_idx]; + } + + // copy k_pe + for (int d = 0; d < rope_size; d++) + { + int const ld_k_pe_idx = ld_k_pe_head_offset + d; + int const st_k_pe_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size + + (s % cache_tokens_per_block) * head_size + (nope_size + d); + output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx]; + } + } + } + } +} + +// Q {total_q, H, D} +// KV {total_kv, 2, H, D} +// softmax_sum {total_q, H, 2} // {max/sum} +// output {total_q, H, D} +// total_q <= total_kv +template +void selfAttentionRef(T* output, T* const Q, T* const KV, int batch_size, int num_heads, int64_t* const cu_seq_q_len, + int64_t* const cu_seq_kv_len, int head_size, bool return_softmax, float* softmax_sum, bool causal_mask) +{ + for (int b = 0; b < batch_size; b++) + { + int curr_q_len = cu_seq_q_len[b + 1] - cu_seq_q_len[b]; + int curr_kv_len = cu_seq_kv_len[b + 1] - cu_seq_kv_len[b]; + int global_q_offset = 
cu_seq_q_len[b] * num_heads * head_size; + int global_kv_offset = cu_seq_kv_len[b] * 2 * num_heads * head_size; + int global_softmax_offset = cu_seq_q_len[b] * num_heads * 2; + float bmm1_scale = 1.F / std::sqrt(static_cast(head_size)); + if (curr_q_len == 0 || curr_kv_len == 0) + { + continue; // skip empty sequences + } + std::vector P(curr_q_len * curr_kv_len); + for (int h = 0; h < num_heads; h++) + { + // BMM1 + std::fill(P.begin(), P.end(), std::numeric_limits::lowest()); + T* const q_ptr = Q + global_q_offset + h * head_size; + T* const k_ptr = KV + global_kv_offset + h * head_size; + T* const v_ptr = k_ptr + num_heads * head_size; + T* output_ptr = output + global_q_offset + h * head_size; + for (int s_q = 0; s_q < curr_q_len; s_q++) + { + float softmax_max = std::numeric_limits::lowest(); + for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) + { + // lower right mask + if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q) + { + break; + } + P[s_q * curr_kv_len + s_kv] = 0; + for (int d = 0; d < head_size; d++) + { + P[s_q * curr_kv_len + s_kv] += static_cast( + q_ptr[s_q * num_heads * head_size + d] * k_ptr[s_kv * 2 * num_heads * head_size + d]); + } + if (softmax_max < P[s_q * curr_kv_len + s_kv]) + { + softmax_max = P[s_q * curr_kv_len + s_kv]; + } + } + for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) + { + // lower right mask + if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q) + { + break; + } + P[s_q * curr_kv_len + s_kv] -= softmax_max; + } + if (return_softmax) + { + softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2] = softmax_max; + } + } + // softmax + for (int s_q = 0; s_q < curr_q_len; s_q++) + { + float sum = 0; + for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) + { + // P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * bmm1_scale); + // hack for real mla kernel + P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * 0.072168784); + sum += P[s_q * curr_kv_len + s_kv]; + } + for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) + { + P[s_q * curr_kv_len + s_kv] /= sum; + } + if (return_softmax) + { + softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2 + 1] = sum; + } + } + // BMM2 + for (int s_q = 0; s_q < curr_q_len; s_q++) + { + for (int d = 0; d < head_size; d++) + { + output_ptr[s_q * num_heads * head_size + d] = 0; + for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) + { + output_ptr[s_q * num_heads * head_size + d] += static_cast(P[s_q * curr_kv_len + s_kv] + * static_cast(v_ptr[s_kv * 2 * num_heads * head_size + d])); + } + } + } + } + } +} + +// chunked_KV {total_chunk_token, 2, H, D} +// KV {total_kv_token, 2, H, D} +template +void copyRelatedChunkedKV(T* chunked_kv, T* const kv, int chunk_idx, int chunk_size, int batch_size, int num_heads, + int64_t* const cu_kv_seq_len, int64_t* const cu_chunked_seq_len, int head_size) +{ + for (int b = 0; b < batch_size; b++) + { + int src_global_offset = (cu_kv_seq_len[b] + chunk_idx * chunk_size) * 2 * num_heads * head_size; + int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size; + int copy_length = cu_chunked_seq_len[b + 1] - cu_chunked_seq_len[b]; + if (copy_length <= 0) + { + continue; // skip empty sequences + } + + std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset, + copy_length * 2 * num_heads * head_size * sizeof(T)); + } +} + +// chunked_KV {total_chunk_token, 2, H, D} +// KV {total_kv_token, 2, H, D} +// It will copy the last chunk of KV cache to chunked_KV cache and calculate the cu_chunked_seq_len +template +void 
copyFinalChunkedKV(T* chunked_kv, T* const kv, int chunk_size, int batch_size, int num_heads, + int64_t* const cu_kv_seq_len, int64_t* cu_chunked_seq_len, int head_size, int64_t* merge_op) +{ + cu_chunked_seq_len[0] = 0; + for (int b = 0; b < batch_size; b++) + { + int curr_kv_len = cu_kv_seq_len[b + 1] - cu_kv_seq_len[b]; + int last_chunk_size = curr_kv_len % chunk_size; + if (last_chunk_size == 0) + { + last_chunk_size = chunk_size; // ensure at least one chunk + } + if (last_chunk_size == curr_kv_len) + { + merge_op[b] = 2; // no need to merge, just copy + } + else + { + merge_op[b] = 1; + } + cu_chunked_seq_len[b + 1] = cu_chunked_seq_len[b] + last_chunk_size; + int global_token_offset = cu_kv_seq_len[b] + curr_kv_len - last_chunk_size; + int copy_length = last_chunk_size; + if (copy_length <= 0) + { + printf("copy_length is zero for batch %d, skipping...\n", b); + continue; // skip empty sequences + } + int src_global_offset = global_token_offset * 2 * num_heads * head_size; + int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size; + std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset, + copy_length * 2 * num_heads * head_size * sizeof(T)); + } +} + +template +float getTolerance(float scale = 1.f) +{ + float tol = 0.0; + if constexpr (std::is_same_v) + { + tol = 0.1; + } + else if constexpr (std::is_same_v) + { + tol = 0.001; + } + else if constexpr (std::is_same_v) + { + tol = 0.005; + } + else if constexpr (std::is_same_v) + { + tol = 0.05; + } + // Keep the scale in a sane range + return std::max(tol, scale * tol); +} +}; // namespace + +template +class MlaChunkedPrefillTest : public ::testing::Test +{ +protected: + using DataType = _DataType; + + std::shared_ptr mStream; + + tensorrt_llm::runtime::BufferManager::ITensorPtr h_kv_cache_tensor{nullptr}, h_kv_cache_tensor_ref{nullptr}, + d_kv_cache_tensor{nullptr}, h_compressed_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor{nullptr}, + h_compressed_offset_tensor{nullptr}, d_compressed_offset_tensor{nullptr}, h_cu_kv_seq_lens{nullptr}, + d_cu_kv_seq_lens{nullptr}, h_cu_chunk_lens{nullptr}, d_cu_chunk_lens{nullptr}, h_cu_q_seq_lens{nullptr}, + d_cu_q_seq_lens{nullptr}, + + // for kernel 1 + h_compressed_kv_output{nullptr}, d_compressed_kv_output{nullptr}, h_k_pe_output{nullptr}, + d_k_pe_output{nullptr}, h_compressed_kv_output_ref{nullptr}, h_k_pe_output_ref{nullptr}, + + // for kernel 2 + h_kv_tensor{nullptr}, d_kv_tensor{nullptr}, h_k_pe_tensor{nullptr}, d_k_pe_tensor{nullptr}, + + // for merge attn {kv_full_tensor = kv + k_pe} + m_h_q_tensor{nullptr}, m_h_kv_full_tensor{nullptr}, m_h_chunked_kv_tensor{nullptr}, m_h_output_tensor{nullptr}, + m_h_softmax_sum_tensor{nullptr}, m_h_softmax_sum_accum_tensor{nullptr}, m_h_output_tensor_ref{nullptr}, + m_h_output_tensor_accum{nullptr}, m_d_q_tensor{nullptr}, m_d_kv_full_tensor{nullptr}, + m_d_chunked_kv_tensor{nullptr}, m_d_output_tensor{nullptr}, m_d_softmax_sum_tensor{nullptr}, + m_d_softmax_sum_accum_tensor{nullptr}, m_d_output_tensor_accum{nullptr}, m_h_merge_op{nullptr}, + m_d_merge_op{nullptr}; + + int mBatchSize{}; + int mMaxSeqLen{}; + int mMaxQSeqLen{}; + int mTotalQLen{}; + int mTotalKVLen{}; + int mChunkSize{}; + int mNumHeads{}; + int mLoraSize{}; + int mRopeSize{}; + int mNopeSize{}; + int mMaxGenLength{}; + // int mHeadSize{}; + int mTokensPerBlock{}; + int mMaxBlockPerSeq{}; + bool mIsCausalMask{}; + + std::mt19937 gen; + + void SetUp() override + { + if (shouldSkip()) + { + GTEST_SKIP() << "Skipping mla chunked prefill test"; + } + 
mStream = std::make_shared(); + gen.seed(42U); + } + + static bool shouldSkip() + { + return false; + } + + void setDefaultParams() + { + mBatchSize = 16; + // mMaxSeqLen = 128; + mChunkSize = 16; + mNumHeads = 16; + mLoraSize = 512; + mRopeSize = 64; + mNopeSize = 128; + mIsCausalMask = false; + mMaxGenLength = 128; + mTokensPerBlock = 16; + assert(this->mChunkSize % this->mTokensPerBlock == 0); + } + + void memsetZeroHost(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor) + { + void* ptr = tensor->data(); + std::memset(ptr, 0, tensor->getSizeInBytes()); + } + + template + void showHostTensor(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor) + { + auto* const ptr = reinterpret_cast(tensor->data()); + for (int _ = 0; _ < tensor->getSize(); _++) + { + std::cout << static_cast(ptr[_]) << " "; + } + std::cout << std::endl; + } + + int generateRandomSizeSmallerThan(int a) + { + if (a <= 0) + { + return 0; + } + std::uniform_int_distribution<> distrib(0, a - 1); + // Generate and return the random number + return int{distrib(gen)}; + } + + float generateRandomFloat(float min, float max) + { + std::uniform_real_distribution dist(min, max); + return dist(gen); + } + + template + void generateRandomData(T* data, int size) + { + for (int i = 0; i < size; i++) + { + data[i] = static_cast(generateRandomFloat(-1.0f, 1.0f)); + } + } + + template + void fillKVOffsetData(T* arr, size_t size, bool use_both_kv = true, int max_block_per_seq = 0) + { + if (use_both_kv) + { + for (int i = 0; i < size; i++) + { + arr[i] = static_cast(i); + } + } + else + { + int temp_idx = 0; + for (int i = 0; i < size; i++) + { + bool is_v = (((i / max_block_per_seq) % 2) == 1); + if (is_v) + { + arr[i] = static_cast(0); + } + else + { + arr[i] = static_cast(temp_idx); + temp_idx++; + } + } + } + } + + template + void fillArrayDataWithMod(T* arr, size_t size) + { + for (int i = 0; i < size; i++) + { + arr[i] = static_cast(i % 448); + } + } + + bool allocateBuffers() + { + using tensorrt_llm::runtime::BufferManager; + using tensorrt_llm::runtime::CudaStream; + using tensorrt_llm::runtime::ITensor; + using tensorrt_llm::runtime::bufferCast; + + auto dtype = nvinfer1::DataType::kHALF; + if constexpr (std::is_same_v) + { + dtype = nvinfer1::DataType::kFLOAT; + } + else if constexpr (std::is_same_v) + { + dtype = nvinfer1::DataType::kHALF; + } + else if constexpr (std::is_same_v) + { + dtype = nvinfer1::DataType::kBF16; + } + else + { + return false; + } + + // cu lens + this->h_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); + this->h_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); + this->h_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); + this->d_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync( + this->h_cu_kv_seq_lens->getShape(), nvinfer1::DataType::kINT64); + this->d_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::gpuSync( + this->h_cu_chunk_lens->getShape(), nvinfer1::DataType::kINT64); + this->d_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync( + this->h_cu_q_seq_lens->getShape(), nvinfer1::DataType::kINT64); + { + this->mMaxSeqLen = 0; + this->mMaxQSeqLen = 0; + this->mTotalQLen = 0; + this->mTotalKVLen = 0; + // we only initialize cu_seq_lens + auto* cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); 
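+ // The loop below draws a random kv length per request and makes the q length equal to the
+ // trailing partial chunk of that kv length, mimicking chunked prefill where every earlier
+ // chunk is already resident in the KV cache.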
+ auto* cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); + cu_kv_seq_lens_ptr[0] = 0; + cu_q_seq_lens_ptr[0] = 0; + for (int i = 0; i < this->mBatchSize; i++) + { + int temp_seq_len = this->generateRandomSizeSmallerThan(this->mMaxGenLength); + if (temp_seq_len == 0) + { + temp_seq_len = 1; // ensure at least one token + } + this->mMaxSeqLen = std::max(this->mMaxSeqLen, temp_seq_len); + cu_kv_seq_lens_ptr[i + 1] = cu_kv_seq_lens_ptr[i] + temp_seq_len; + auto temp_q_seq_len = temp_seq_len % this->mChunkSize; + if (temp_q_seq_len == 0) + { + temp_q_seq_len = this->mChunkSize; // ensure at least one chunk + } + cu_q_seq_lens_ptr[i + 1] = cu_q_seq_lens_ptr[i] + temp_q_seq_len; + this->mMaxQSeqLen = std::max(this->mMaxQSeqLen, temp_q_seq_len); + this->mTotalQLen += temp_q_seq_len; + this->mTotalKVLen += temp_seq_len; + } + cudaMemcpy(this->d_cu_kv_seq_lens->data(), this->h_cu_kv_seq_lens->data(), + this->h_cu_kv_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice); + cudaMemcpy(this->d_cu_q_seq_lens->data(), this->h_cu_q_seq_lens->data(), + this->h_cu_q_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice); +#ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG + this->showHostTensor(this->h_cu_q_seq_lens); + this->showHostTensor(this->h_cu_kv_seq_lens); +#endif + } + // kv cache + this->mMaxBlockPerSeq = (this->mMaxSeqLen + this->mTokensPerBlock - 1) / this->mTokensPerBlock; + int maxChunkBlockPerSeq = (this->mChunkSize + this->mTokensPerBlock - 1) / this->mTokensPerBlock; + this->h_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, + this->mNopeSize + this->mRopeSize}), + dtype); + + this->h_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, + this->mNopeSize + this->mRopeSize}), + dtype); + + this->h_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, + this->mLoraSize + this->mRopeSize}), + dtype); + this->h_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq + 1}), nvinfer1::DataType::kINT32); + this->d_kv_cache_tensor + = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_cache_tensor->getShape(), dtype); + this->d_compressed_kv_cache_tensor + = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_cache_tensor->getShape(), dtype); + this->d_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( + this->h_compressed_offset_tensor->getShape(), nvinfer1::DataType::kINT32); + + { + auto* compressed_kv_cache_ptr = bufferCast(*(this->h_compressed_kv_cache_tensor)); + auto* offset_ptr = bufferCast(*(this->h_compressed_offset_tensor)); + + this->memsetZeroHost(this->h_kv_cache_tensor); + this->memsetZeroHost(this->h_kv_cache_tensor_ref); + + this->fillArrayDataWithMod(compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize()); + this->fillKVOffsetData( + offset_ptr, this->h_compressed_offset_tensor->getSize(), false, this->mMaxBlockPerSeq); + cudaMemcpy(this->d_kv_cache_tensor->data(), this->h_kv_cache_tensor->data(), + this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); + cudaMemcpy(this->d_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->data(), + 
this->h_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); + cudaMemcpy(this->d_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->data(), + this->h_compressed_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); + } + + // tensor + // kv, k_pe for invokeMLALoadChunkedKV (kernel 1) + this->h_compressed_kv_output = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype); + this->h_k_pe_output = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); + this->h_compressed_kv_output_ref = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype); + this->h_k_pe_output_ref = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); + this->d_compressed_kv_output + = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_output->getShape(), dtype); + this->d_k_pe_output = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_output->getShape(), dtype); + { + this->memsetZeroHost(this->h_compressed_kv_output); + this->memsetZeroHost(this->h_k_pe_output); + this->memsetZeroHost(this->h_compressed_kv_output_ref); + this->memsetZeroHost(this->h_k_pe_output_ref); + + cudaMemcpy(this->d_compressed_kv_output->data(), this->h_compressed_kv_output->data(), + this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyHostToDevice); + cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(), + cudaMemcpyHostToDevice); + } + + // kv, k_pe for invokeMLASetChunkedKV (kernel 2) + this->h_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype); + this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); + this->d_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_tensor->getShape(), dtype); + this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_tensor->getShape(), dtype); + { + auto* kv_ptr = bufferCast(*(this->h_kv_tensor)); + auto* k_pe_ptr = bufferCast(*(this->h_k_pe_tensor)); + + fillArrayDataWithMod(kv_ptr, h_kv_tensor->getSize()); + fillArrayDataWithMod(k_pe_ptr, h_k_pe_tensor->getSize()); + + cudaMemcpyAsync(d_kv_tensor->data(), h_kv_tensor->data(), h_kv_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(d_k_pe_tensor->data(), h_k_pe_tensor->data(), h_k_pe_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice, mStream->get()); + cudaStreamSynchronize(mStream->get()); + } + + // invokeMergeAttnWithSoftmax, we just ignore rope_size here for simplicity + + this->m_h_q_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_kv_full_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mTotalKVLen, 2, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_chunked_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_output_tensor = tensorrt_llm::runtime::BufferManager::pinned( 
+ ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT); + this->m_h_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT); + this->m_h_output_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_output_tensor_accum = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); + this->m_h_merge_op = tensorrt_llm::runtime::BufferManager::pinned( + ITensor::makeShape({this->mBatchSize}), nvinfer1::DataType::kINT64); + this->m_d_q_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_q_tensor->getShape(), dtype); + this->m_d_kv_full_tensor + = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_kv_full_tensor->getShape(), dtype); + this->m_d_chunked_kv_tensor + = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_chunked_kv_tensor->getShape(), dtype); + this->m_d_output_tensor + = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor->getShape(), dtype); + this->m_d_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( + this->m_h_softmax_sum_tensor->getShape(), nvinfer1::DataType::kFLOAT); + this->m_d_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( + this->m_h_softmax_sum_accum_tensor->getShape(), nvinfer1::DataType::kFLOAT); + this->m_d_output_tensor_accum + = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor_accum->getShape(), dtype); + this->m_d_merge_op + = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_merge_op->getShape(), nvinfer1::DataType::kINT64); + + { + auto* q_ptr = bufferCast(*(this->m_h_q_tensor)); + auto* kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); + + generateRandomData(q_ptr, m_h_q_tensor->getSize()); + generateRandomData(kv_ptr, m_h_kv_full_tensor->getSize()); + this->memsetZeroHost(m_h_chunked_kv_tensor); + this->memsetZeroHost(m_h_output_tensor); + this->memsetZeroHost(m_h_softmax_sum_tensor); + this->memsetZeroHost(m_h_softmax_sum_accum_tensor); + this->memsetZeroHost(m_h_output_tensor_ref); + this->memsetZeroHost(m_h_output_tensor_accum); + + // Copy data to device + cudaMemcpyAsync(m_d_q_tensor->data(), m_h_q_tensor->data(), m_h_q_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_kv_full_tensor->data(), m_h_kv_full_tensor->data(), + m_h_kv_full_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_chunked_kv_tensor->data(), m_h_chunked_kv_tensor->data(), + m_h_chunked_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_output_tensor->data(), m_h_output_tensor->data(), m_h_output_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_softmax_sum_tensor->data(), m_h_softmax_sum_tensor->data(), + m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_softmax_sum_accum_tensor->data(), m_h_softmax_sum_accum_tensor->data(), + m_h_softmax_sum_accum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); + cudaMemcpyAsync(m_d_output_tensor_accum->data(), m_h_output_tensor_accum->data(), + 
m_h_output_tensor_accum->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); + cudaStreamSynchronize(mStream->get()); + } + return true; + } + + void PerformNormalAttention() + { + using tensorrt_llm::runtime::bufferCast; + + auto* q_ptr = bufferCast(*(this->m_h_q_tensor)); + auto* kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); + auto* output_ptr = bufferCast(*(this->m_h_output_tensor_ref)); + auto* cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); + auto* cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); + selfAttentionRef(output_ptr, q_ptr, kv_ptr, this->mBatchSize, this->mNumHeads, cu_q_seq_lens_ptr, + cu_kv_seq_lens_ptr, this->mNopeSize, false, nullptr, this->mIsCausalMask); + } + + void PerformMergedAttention() + { + using tensorrt_llm::runtime::bufferCast; + + auto* h_q_ptr = bufferCast(*(this->m_h_q_tensor)); + auto* h_kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); + auto* h_chunked_kv_ptr = bufferCast(*(this->m_h_chunked_kv_tensor)); + auto* h_output_ptr = bufferCast(*(this->m_h_output_tensor)); + auto* h_output_accum_ptr = bufferCast(*(this->m_h_output_tensor_accum)); + auto* h_softmax_sum_ptr = bufferCast(*(this->m_h_softmax_sum_tensor)); + auto* h_softmax_sum_accum_ptr = bufferCast(*(this->m_h_softmax_sum_accum_tensor)); + auto* h_cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); + auto* h_cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); + auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); + auto* h_merge_op = bufferCast(*(this->m_h_merge_op)); + auto* d_kv_ptr = bufferCast(*(this->m_d_kv_full_tensor)); + auto* d_chunked_kv_ptr = bufferCast(*(this->m_d_chunked_kv_tensor)); + auto* d_softmax_sum_ptr = bufferCast(*(this->m_d_softmax_sum_tensor)); + auto* d_softmax_sum_accum_ptr = bufferCast(*(this->m_d_softmax_sum_accum_tensor)); + auto* d_output_ptr = bufferCast(*(this->m_d_output_tensor)); + auto* d_output_accum_ptr = bufferCast(*(this->m_d_output_tensor_accum)); + auto* d_merge_op = bufferCast(*(this->m_d_merge_op)); + auto* d_cu_q_seq_lens_ptr = bufferCast(*(this->d_cu_q_seq_lens)); + + int const loop_count = (this->mMaxSeqLen + this->mChunkSize - 1) / this->mChunkSize; + // do not apply mask + for (int _ = 0; _ < loop_count - 1; _++) + { + // get chunked len for each request + this->PrepareChunkedLen(_); + cudaMemcpy(d_merge_op, h_merge_op, this->m_h_merge_op->getSizeInBytes(), cudaMemcpyHostToDevice); + // copy related kv chunk data + copyRelatedChunkedKV(h_chunked_kv_ptr, h_kv_ptr, _, this->mChunkSize, this->mBatchSize, this->mNumHeads, + h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize); + // attention + selfAttentionRef(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads, + h_cu_q_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, false); + // merge attention + + // copy curr_attn and softmax_sum to device + cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice); + cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); + // merge softmax + tensorrt_llm::kernels::invokeMergeAttnWithSoftmax(d_output_accum_ptr, d_softmax_sum_accum_ptr, + d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize, + d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get()); + cudaStreamSynchronize(mStream->get()); + // copy merged softmax sum back to host + 
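// (the running device accumulators are copied back only so the host can inspect them and
+ // run the final comparison against the reference) +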
cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(), + cudaMemcpyDeviceToHost); + } + // final round, apply causal mask. + + // copy the last chunked kv data + copyFinalChunkedKV(h_chunked_kv_ptr, h_kv_ptr, this->mChunkSize, this->mBatchSize, this->mNumHeads, + h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, h_merge_op); + cudaMemcpy(d_merge_op, h_merge_op, this->m_h_merge_op->getSizeInBytes(), cudaMemcpyHostToDevice); +#ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG + std::cout << "merge op: "; + this->showHostTensor(this->m_h_merge_op); + std::cout << "cu chunk lens: "; + this->showHostTensor(this->h_cu_chunk_lens); +#endif + // attention + selfAttentionRef(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads, + h_cu_q_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, this->mIsCausalMask); + // merge attention + // copy curr_attn and softmax_sum to device + cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), + cudaMemcpyHostToDevice); + cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); + tensorrt_llm::kernels::invokeMergeAttnWithSoftmax(d_output_accum_ptr, d_softmax_sum_accum_ptr, + d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize, + d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get()); + cudaStreamSynchronize(mStream->get()); + // copy merged softmax sum back to host + cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), + cudaMemcpyDeviceToHost); + cudaMemcpy( + h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); + sync_check_cuda_error(mStream->get()); + } + + void PrepareChunkedLen(int chunk_idx) + { + using tensorrt_llm::runtime::bufferCast; + auto* h_merge_op = bufferCast(*(this->m_h_merge_op)); + auto* h_cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); + auto* h_cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); + auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); + + h_cu_chunk_lens_ptr[0] = 0; + for (int b = 0; b < this->mBatchSize; b++) + { + int curr_kv_len = h_cu_kv_seq_lens_ptr[b + 1] - h_cu_kv_seq_lens_ptr[b]; + int used_kv_len = chunk_idx * this->mChunkSize; + int curr_chunk_len = std::min(this->mChunkSize, curr_kv_len - used_kv_len); + if (curr_chunk_len != this->mChunkSize) + { + // last chunk, we should skip it. + curr_chunk_len = 0; + } + else + { + if (used_kv_len + curr_chunk_len == curr_kv_len) + { + // last chunk, we should skip it. 
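+ // (the trailing chunk is handled by the separate final pass, which applies the causal mask)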
+ curr_chunk_len = 0; + } + } + h_cu_chunk_lens_ptr[b + 1] = h_cu_chunk_lens_ptr[b] + curr_chunk_len; + if (chunk_idx == 0 && curr_chunk_len > 0) + { + h_merge_op[b] = 2; // only copy result + } + else if (curr_chunk_len > 0) + { + h_merge_op[b] = 1; // merge result + } + else + { + h_merge_op[b] = 0; // skip + } + } +#ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG + std::cout << "merge op: "; + this->showHostTensor(this->m_h_merge_op); + std::cout << "cu chunk lens: "; + this->showHostTensor(this->h_cu_chunk_lens); +#endif + } + + void PerformLoadChunkedKVRef(int chunk_idx) + { + using tensorrt_llm::runtime::bufferCast; + + auto* compressed_kv_output_ptr = bufferCast(*(this->h_compressed_kv_output_ref)); + auto* k_pe_output_ptr = bufferCast(*(this->h_k_pe_output_ref)); + auto* compressed_kv_cache_ptr = bufferCast(*(this->h_compressed_kv_cache_tensor)); + auto* offset_ptr = bufferCast(*(this->h_compressed_offset_tensor)); + auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); + + tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock, + sizeof(DataType) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr, + reinterpret_cast(offset_ptr)); + this->PrepareChunkedLen(chunk_idx); + + loadChunkedKVKernelRef(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, this->mBatchSize, + h_cu_chunk_lens_ptr, this->mLoraSize, this->mRopeSize, this->mChunkSize, chunk_idx); + } + + void PreformLoadChunkedKV(int chunk_idx) + { + using tensorrt_llm::runtime::bufferCast; + + auto* compressed_kv_output_ptr = bufferCast(*(this->d_compressed_kv_output)); + auto* k_pe_output_ptr = bufferCast(*(this->d_k_pe_output)); + auto* compressed_kv_cache_ptr = bufferCast(*(this->d_compressed_kv_cache_tensor)); + auto* offset_ptr = bufferCast(*(this->d_compressed_offset_tensor)); + auto* d_cu_chunk_lens_ptr = bufferCast(*(this->d_cu_chunk_lens)); + + tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock, + sizeof(DataType) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr, + reinterpret_cast(offset_ptr)); + this->PrepareChunkedLen(chunk_idx); + // copy cu chunk lens to device + cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(), + this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice); + tensorrt_llm::kernels::invokeMLALoadChunkedKV(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, + this->mBatchSize, d_cu_chunk_lens_ptr, this->mLoraSize, this->mRopeSize, this->mChunkSize, chunk_idx, + mStream->get()); + cudaStreamSynchronize(this->mStream->get()); + // copy result back to host + cudaMemcpy(this->h_compressed_kv_output->data(), compressed_kv_output_ptr, + this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyDeviceToHost); + cudaMemcpy(this->h_k_pe_output->data(), k_pe_output_ptr, this->h_k_pe_output->getSizeInBytes(), + cudaMemcpyDeviceToHost); + sync_check_cuda_error(this->mStream->get()); + } + + void PerformSetChunkedKVRef() + { + using tensorrt_llm::runtime::bufferCast; + auto* kv_ptr = bufferCast(*(this->h_kv_tensor)); + auto* k_pe_ptr = bufferCast(*(this->h_k_pe_tensor)); + auto* kv_cache_ptr = bufferCast(*(this->h_kv_cache_tensor_ref)); + auto* cu_chunked_seq_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); + this->PrepareChunkedLen(0); + setChunkedKVCacheForMLAKernelRef(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, cu_chunked_seq_lens_ptr, + this->mChunkSize, this->mNumHeads, 
this->mNopeSize, this->mRopeSize, this->mTokensPerBlock); + } + + void PerformSetChunkedKV() + { + using tensorrt_llm::runtime::bufferCast; + auto* kv_ptr = bufferCast(*(this->d_kv_tensor)); + auto* k_pe_ptr = bufferCast(*(this->d_k_pe_tensor)); + auto* kv_cache_ptr = bufferCast(*(this->d_kv_cache_tensor)); + auto* cu_chunked_seq_lens_ptr = bufferCast(*(this->d_cu_chunk_lens)); + this->PrepareChunkedLen(0); + // copy cu chunk lens to device + cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(), + this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice); + tensorrt_llm::kernels::invokeMLASetChunkedKV(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, this->mChunkSize, + this->mNumHeads, this->mNopeSize, this->mRopeSize, cu_chunked_seq_lens_ptr, this->mTokensPerBlock, + mStream->get()); + cudaStreamSynchronize(this->mStream->get()); + // copy result back to host + cudaMemcpy(this->h_kv_cache_tensor->data(), kv_cache_ptr, this->h_kv_cache_tensor->getSizeInBytes(), + cudaMemcpyDeviceToHost); + sync_check_cuda_error(this->mStream->get()); + } +}; + +using MLATypes = ::testing::Types; + +TYPED_TEST_SUITE(MlaChunkedPrefillTest, MLATypes); + +TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillDefault) +{ + using tensorrt_llm::runtime::bufferCast; + using DataType = typename TestFixture::DataType; + this->setDefaultParams(); + this->allocateBuffers(); + + sync_check_cuda_error(this->mStream->get()); + bool allEqual{true}; + + this->PerformNormalAttention(); + sync_check_cuda_error(this->mStream->get()); + + this->PerformMergedAttention(); + sync_check_cuda_error(this->mStream->get()); + + // check result + auto* output_ptr = bufferCast(*(this->m_h_output_tensor_accum)); + auto* output_ref_ptr = bufferCast(*(this->m_h_output_tensor_ref)); + for (int i = 0; i < this->m_h_output_tensor->getSize(); i++) + { + if (std::abs(static_cast(output_ptr[i]) - static_cast(output_ref_ptr[i])) + > getTolerance(output_ptr[i])) + { + std::cout << "Output mismatch at index " << i << ": " + << "expected " << static_cast(output_ref_ptr[i]) << ", got " + << static_cast(output_ptr[i]) << std::endl; + allEqual = false; + break; + } + } + ASSERT_TRUE(allEqual); +} + +TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillCausalMask) +{ + using tensorrt_llm::runtime::bufferCast; + using DataType = typename TestFixture::DataType; + this->setDefaultParams(); + this->mIsCausalMask = true; + this->allocateBuffers(); + + sync_check_cuda_error(this->mStream->get()); + bool allEqual{true}; + + this->PerformNormalAttention(); + sync_check_cuda_error(this->mStream->get()); + + this->PerformMergedAttention(); + sync_check_cuda_error(this->mStream->get()); + + // check result + auto* output_ptr = bufferCast(*(this->m_h_output_tensor_accum)); + auto* output_ref_ptr = bufferCast(*(this->m_h_output_tensor_ref)); + for (int i = 0; i < this->m_h_output_tensor->getSize(); i++) + { + if (std::abs(static_cast(output_ptr[i]) - static_cast(output_ref_ptr[i])) + > getTolerance(output_ptr[i])) + { + std::cout << "Output mismatch at index " << i << ": " + << "expected " << static_cast(output_ref_ptr[i]) << ", got " + << static_cast(output_ptr[i]) << std::endl; + allEqual = false; + break; + } + } + ASSERT_TRUE(allEqual); +} + +TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedLoad) +{ + using tensorrt_llm::runtime::bufferCast; + using DataType = typename TestFixture::DataType; + this->setDefaultParams(); + this->allocateBuffers(); + + sync_check_cuda_error(this->mStream->get()); + bool allEqual{true}; + + int const loop_count = 
(this->mMaxSeqLen + this->mChunkSize - 1) / this->mChunkSize; + for (int _ = 0; _ < loop_count - 1; _++) + { + this->PerformLoadChunkedKVRef(_); + sync_check_cuda_error(this->mStream->get()); + this->PreformLoadChunkedKV(_); + sync_check_cuda_error(this->mStream->get()); + + // check result + auto* compressed_kv_output_ptr = bufferCast(*(this->h_compressed_kv_output_ref)); + auto* compressed_kv_output_ref_ptr = bufferCast(*(this->h_compressed_kv_output)); + auto* k_pe_output_ptr = bufferCast(*(this->h_k_pe_output)); + auto* k_pe_output_ref_ptr = bufferCast(*(this->h_k_pe_output_ref)); + // check kv + for (int i = 0; i < this->h_compressed_kv_output->getSize(); i++) + { + if (std::abs(static_cast(compressed_kv_output_ptr[i]) + - static_cast(compressed_kv_output_ref_ptr[i])) + > getTolerance(compressed_kv_output_ptr[i])) + { + std::cout << "Compressed KV output mismatch at loop: " << _ << " index " << i << ": " + << "expected " << static_cast(compressed_kv_output_ref_ptr[i]) << ", got " + << static_cast(compressed_kv_output_ptr[i]) << std::endl; + allEqual = false; + break; + } + } + // check k_pe + for (int i = 0; i < this->h_k_pe_output->getSize(); i++) + { + if (std::abs(static_cast(k_pe_output_ptr[i]) - static_cast(k_pe_output_ref_ptr[i])) + > getTolerance(k_pe_output_ptr[i])) + { + std::cout << "kpe mismatch at loop: " << _ << " index " << i << ": " + << "expected " << static_cast(k_pe_output_ref_ptr[i]) << ", got " + << static_cast(k_pe_output_ptr[i]) << std::endl; + allEqual = false; + break; + } + } + } + ASSERT_TRUE(allEqual); +} + +TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedSet) +{ + using tensorrt_llm::runtime::bufferCast; + using DataType = typename TestFixture::DataType; + this->setDefaultParams(); + this->allocateBuffers(); + + sync_check_cuda_error(this->mStream->get()); + bool allEqual{true}; + + this->PerformSetChunkedKVRef(); + sync_check_cuda_error(this->mStream->get()); + this->PerformSetChunkedKV(); + sync_check_cuda_error(this->mStream->get()); + + // check result + auto* kv_cache_ptr = bufferCast(*(this->h_kv_cache_tensor)); + auto* kv_cache_ptr_ref = bufferCast(*(this->h_kv_cache_tensor_ref)); + + for (int i = 0; i < this->h_kv_cache_tensor->getSize(); i++) + { + if (std::abs(static_cast(kv_cache_ptr[i]) - static_cast(kv_cache_ptr_ref[i])) + > getTolerance(kv_cache_ptr[i])) + { + std::cout << "KV cache mismatch at index " << i << ": " + << "expected " << static_cast(kv_cache_ptr_ref[i]) << ", got " + << static_cast(kv_cache_ptr[i]) << std::endl; + allEqual = false; + break; + } + } + ASSERT_TRUE(allEqual); +} diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 03f41c6424b..20d487708f2 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -24,6 +24,8 @@ class AttentionRuntimeFeatures: chunked_prefill: bool = False cache_reuse: bool = False has_speculative_draft_tokens: bool = False + chunk_unit_size: int = 0 + normal_chunk_size: int = 0 # The type of requests in qkv passed to attention diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 46585c4b2ed..8687192e358 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -168,6 +168,7 @@ def plan( mrope_config: Optional[dict] = None, mla_context_paged_kv: Optional[torch.Tensor] = None, mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None, + 
softmax_stats_tensor: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -202,6 +203,7 @@ def plan( mrope_config (dict): The dictionary containing the mRope configuration. mla_context_paged_kv (torch.Tensor): The paged KV cache for MLA context, for kv cache reuse/chunked context. mla_context_kv_cache_block_offsets (torch.Tensor): The block offsets for the paged KV cache for MLA context, for kv cache reuse/chunked context. + softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) """ self.layer_idx = layer_idx self.tokens_per_block = tokens_per_block @@ -237,6 +239,7 @@ def plan( self.block_ids_per_seq = block_ids_per_seq self.mla_context_paged_kv = mla_context_paged_kv self.mla_context_kv_cache_block_offsets = mla_context_kv_cache_block_offsets + self.softmax_stats_tensor = softmax_stats_tensor if max_sequence_length > self.rope_params.max_positions: self.rope_params.max_positions = max_sequence_length @@ -435,6 +438,7 @@ def run( self.mla_context_paged_kv, self.mla_context_kv_cache_block_offsets, self.attention_chunk_size, + self.softmax_stats_tensor, ) # reset the planned states (especially tensors) to avoid memory leak @@ -602,6 +606,16 @@ def __post_init__(self) -> None: device='cpu', pin_memory=True, ) + self.ctx_uncached_token_indptr = torch.zeros( + (self.max_num_requests + 1, ), + device='cuda', + dtype=torch.int64, + ) + self.host_ctx_uncached_token_indptr = torch.zeros_like( + self.ctx_uncached_token_indptr, + device='cpu', + pin_memory=True, + ) # context full seqlens include cached tokens and uncached tokens self.ctx_kv_indptr = torch.zeros( (self.max_num_requests + 1, ), @@ -702,17 +716,114 @@ def prepare_flash_mla(self) -> None: self.host_request_types_runtime = self.host_request_types[:self. num_seqs] + def pre_process_for_chunked_prefill( + self, + chunked_seq_len: torch.Tensor, + cu_chunked_seq_len: torch.Tensor, + merge_op_tensor: torch.Tensor, + chunked_loop_num: int, + ) -> None: + """ + Pre-process the MLA layer for chunked prefill. + This method is called before the forward pass to prepare the MLA layer for chunked prefill. + """ + num_contexts = self.num_contexts + chunk_size = self.runtime_features.normal_chunk_size + cached_kv_lens = torch.tensor( + self.kv_cache_params.num_cached_tokens_per_seq, + dtype=torch.int, + device='cpu', + ) + for loop_idx in range(chunked_loop_num): + cu_chunked_seq_len[loop_idx, 0] = 0 + used_chunk_seq_len = loop_idx * chunk_size + chunked_seq_len[loop_idx, :num_contexts] = torch.clamp( + cached_kv_lens[:num_contexts] - used_chunk_seq_len, + min=0, + max=chunk_size) + torch.cumsum(chunked_seq_len[loop_idx, :num_contexts], + dim=0, + dtype=torch.int64, + out=cu_chunked_seq_len[loop_idx, 1:num_contexts + 1]) + for s in range(num_contexts): + if loop_idx == 0 and chunked_seq_len[loop_idx, s] > 0: + merge_op_tensor[loop_idx, s] = 2 # copy only + elif chunked_seq_len[loop_idx, s] > 0: + merge_op_tensor[loop_idx, s] = 1 # merge + else: + merge_op_tensor[loop_idx, s] = 0 # skip + + # set merge op for last attn + for s in range(num_contexts): + if cached_kv_lens[s] == 0: + merge_op_tensor[chunked_loop_num, s] = 2 # copy only + else: + merge_op_tensor[chunked_loop_num, s] = 1 # merge + def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor, kv_lens: torch.Tensor) -> None: if self.num_contexts > 0: self.num_ctx_cached_tokens = cached_token_lens[:self. num_contexts].sum( ).item() + self.max_ctx_cached_token_len = cached_token_lens[:self. 
+ num_contexts].max( + ).item() self.max_ctx_kv_len = kv_lens[:self.num_contexts].max().item() self.max_ctx_seq_len = self.seq_lens[:self.num_contexts].max().item( ) + # determine the number of loop + # currently we assume that the chunk size is the same as the max_num_tokens + if self.runtime_features.chunked_prefill: + chunk_size = self.runtime_features.normal_chunk_size + self.chunked_loop_num = (self.max_ctx_cached_token_len + + chunk_size - 1) // chunk_size + self.chunked_seq_len = torch.empty( + (self.chunked_loop_num, self.num_seqs), + dtype=torch.int, + device='cuda', + ) + self.host_chunked_seq_len = torch.empty_like( + self.chunked_seq_len, + device='cpu', + pin_memory=True, + ) + self.cu_chunked_seq_len = torch.zeros( + (self.chunked_loop_num, self.num_contexts + 1), + dtype=torch.int64, + device='cuda', + ) + self.host_cu_chunked_seq_len = torch.zeros_like( + self.cu_chunked_seq_len, + device='cpu', + pin_memory=True, + ) + # For last chunk we use the uncached kv + self.merge_op_tensor = torch.empty( + (self.chunked_loop_num + 1, self.num_contexts), + dtype=torch.int64, + device='cuda', + ) + self.host_merge_op_tensor = torch.empty_like( + self.merge_op_tensor, + device='cpu', + pin_memory=True, + ) + + self.pre_process_for_chunked_prefill( + chunked_seq_len=self.host_chunked_seq_len, + cu_chunked_seq_len=self.host_cu_chunked_seq_len, + merge_op_tensor=self.host_merge_op_tensor, + chunked_loop_num=self.chunked_loop_num) + self.chunked_seq_len.copy_(self.host_chunked_seq_len, + non_blocking=True) + self.cu_chunked_seq_len.copy_(self.host_cu_chunked_seq_len, + non_blocking=True) + self.merge_op_tensor.copy_(self.host_merge_op_tensor, + non_blocking=True) else: self.num_ctx_cached_tokens = 0 + self.max_ctx_cached_token_len = 0 self.max_ctx_kv_len = 0 self.max_ctx_seq_len = 0 torch.cumsum(cached_token_lens[:self.num_contexts], @@ -723,7 +834,14 @@ def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor, self.ctx_cached_token_indptr[:self.num_contexts + 1].copy_( self.host_ctx_cached_token_indptr[:self.num_contexts + 1], non_blocking=True) - + torch.cumsum( + self.seq_lens[:self.num_contexts], + dim=0, + dtype=torch.int64, + out=self.host_ctx_uncached_token_indptr[1:self.num_contexts + 1]) + self.ctx_uncached_token_indptr[:self.num_contexts + 1].copy_( + self.host_ctx_uncached_token_indptr[:self.num_contexts + 1], + non_blocking=True) torch.cumsum(kv_lens[:self.num_contexts], dim=0, dtype=torch.int64, @@ -835,6 +953,7 @@ def forward( attention_window_size: Optional[int] = None, mla_context_paged_kv: Optional[torch.Tensor] = None, mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None, + softmax_stats_tensor: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, output_sf: Optional[torch.Tensor] = None, **kwargs, @@ -901,6 +1020,7 @@ def forward( mla_context_paged_kv=mla_context_paged_kv, mla_context_kv_cache_block_offsets= mla_context_kv_cache_block_offsets, + softmax_stats_tensor=softmax_stats_tensor, ) out_dtype = None if out_scale is not None: @@ -954,6 +1074,15 @@ def has_cached_kv_for_mla_context( and metadata.enable_paged_context_mla and metadata.num_ctx_cached_tokens > 0) + def is_chunked_prefill_for_mla_context( + self, + metadata: TrtllmAttentionMetadata, + ) -> bool: + return (self.is_mla_enable and metadata.kv_cache_manager is not None + and metadata.enable_paged_context_mla + and metadata.num_ctx_cached_tokens > 0 + and metadata.runtime_features.chunked_prefill) + def load_paged_kv_cache_for_mla( self, metadata: 
TrtllmAttentionMetadata, @@ -993,6 +1122,49 @@ def load_paged_kv_cache_for_mla( return compressed_kv, k_pe + def load_chunked_kv_cache_for_mla( + self, + metadata: TrtllmAttentionMetadata, + chunked_idx: int, + num_ctx_cached_tokens: int, + cu_chunked_seq_len: torch.Tensor, + out_dtype: torch.dtype, + ) -> torch.Tensor: + assert out_dtype in [torch.float16, torch.bfloat16, torch.float32] + assert self.is_mla_enable and self.mla_params is not None + assert metadata.kv_cache_manager is not None + + if metadata.max_ctx_cached_token_len == 0: + return torch.empty((0, metadata.kv_cache_manager.head_dim), + dtype=out_dtype, + device=cu_chunked_seq_len.device) + + sink_token_length = 0 + beam_width = 1 + + output_kv, output_k_pe = torch.ops.trtllm.load_chunked_kv_cache_for_mla( + out_dtype, + metadata.num_contexts, + num_ctx_cached_tokens, + cu_chunked_seq_len, + metadata.kv_cache_block_offsets, + metadata.kv_cache_manager.kv_cache_pool_pointers, + metadata.kv_cache_manager.kv_cache_pool_mapping, + self.kv_scale_orig_quant, + self.kv_scale_quant_orig, + self.layer_idx, + self.mla_params.kv_lora_rank, + self.mla_params.qk_rope_head_dim, + metadata.kv_cache_manager.tokens_per_block, + metadata.runtime_features.normal_chunk_size, + chunked_idx, + metadata.kv_cache_manager.max_seq_len, + sink_token_length, + beam_width, + self.wrapper.quant_mode, + ) + return output_kv, output_k_pe + def set_paged_kv_cache_for_mla( self, paged_kv: torch.Tensor, @@ -1029,6 +1201,50 @@ def set_paged_kv_cache_for_mla( assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num) return paged_kv_offsets + def set_chunked_kv_cache_for_mla( + self, + paged_kv: torch.Tensor, + kv: torch.Tensor, + k_pe: torch.Tensor, + cu_chunked_seq_len: torch.Tensor, + cached: bool, + metadata: TrtllmAttentionMetadata, + ) -> torch.Tensor: + assert self.is_mla_enable and self.mla_params is not None + assert self.mla_params.qk_nope_head_dim == self.mla_params.v_head_dim + assert metadata.kv_cache_manager is not None + assert paged_kv.shape[0] == metadata.num_contexts + assert paged_kv.is_contiguous() + + kv = kv.contiguous() + k_pe = k_pe.contiguous() + + num_contexts = metadata.num_contexts + tokens_per_block = metadata.kv_cache_manager.tokens_per_block + if cached: + # this indptr is the fake. 
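+ # i.e. it describes the layout of the chunk scratch buffer rather than the real cached
+ # KV history, so max_seq_len is capped at the chunk size.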
+ cu_seq_len = cu_chunked_seq_len + max_seq_len = metadata.runtime_features.normal_chunk_size + else: + cu_seq_len = metadata.ctx_uncached_token_indptr + max_seq_len = metadata.max_ctx_seq_len + paged_kv_offsets = torch.ops.trtllm.set_chunked_kv_cache_for_mla( + paged_kv, + kv, + k_pe, + num_contexts, + cu_seq_len, + self.num_heads, + self.mla_params.qk_nope_head_dim, + self.mla_params.qk_rope_head_dim, + metadata.kv_cache_manager.tokens_per_block, + max_seq_len, + ) + + max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block + assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num) + return paged_kv_offsets + def mla_rope_append_paged_kv_assign_q( self, q: torch.Tensor, @@ -1066,3 +1282,28 @@ def mla_rope_append_paged_kv_assign_q( beam_width, self.wrapper.quant_mode, ) + + def merge_attention_for_mla( + self, + merged_attn: torch.Tensor, + temp_attn: torch.Tensor, + softmax_stats: torch.Tensor, + temp_softmax_stats: torch.Tensor, + merge_op: torch.Tensor, + metadata: TrtllmAttentionMetadata, + ) -> None: + assert self.is_mla_enable and self.mla_params is not None + assert metadata.kv_cache_manager is not None + + torch.ops.trtllm.merge_chunked_attention_for_mla( + merged_attn, + temp_attn, + softmax_stats, + temp_softmax_stats, + metadata.num_contexts, + metadata.ctx_uncached_token_indptr, # cu_q_seq_len + metadata.max_ctx_seq_len, # max_q_seq_len + merge_op, + self.num_heads, + self.mla_params.v_head_dim, + ) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 6e1fd681fd8..9b8b2f059b2 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -651,6 +651,7 @@ def attention( mla_context_paged_kv: Optional[torch.Tensor], mla_context_kv_cache_block_offsets: Optional[torch.Tensor], attention_chunk_size: Optional[int], + softmax_stats_tensor: Optional[torch.Tensor], ) -> List[torch.Tensor]: num_tokens = q.size(0) attention_input_type = (AttentionInputType(attention_input_type) @@ -697,7 +698,7 @@ def attention( q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv, mla_context_kv_cache_block_offsets, - attention_chunk_size) + attention_chunk_size, softmax_stats_tensor) return output_act, output_sf diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index df03e74186f..f25b864707a 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -922,6 +922,166 @@ def forward_context_with_cached_kv( return attn_output + def forward_context_with_chunked_prefill( + self, + q: torch.Tensor, + compressed_kv: torch.Tensor, + latent_cache: torch. 
+ Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size] + attn_metadata: TrtllmAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + trtllm_attention = cast(TrtllmAttention, self.mha) + # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q + trtllm_attention.mla_rope_append_paged_kv_assign_q( + q, latent_cache, attn_metadata) + + # determine the number of loop + # currently we assume that the chunk size is the same as the max_num_tokens + chunk_size = attn_metadata.runtime_features.normal_chunk_size + chunked_loop_num = attn_metadata.chunked_loop_num + + # [toal_token_q, num_heads, 2] -> [toal_token_q, num_heads] float2 + self.softmax_stats_tensor = torch.empty( + (attn_metadata.num_ctx_tokens, self.num_heads, 2), + dtype=torch.float, + device='cuda', + ) + self.temp_softmax_stats_tensor = torch.empty( + (attn_metadata.num_ctx_tokens, self.num_heads, 2), + dtype=torch.float, + device='cuda', + ) + if output is None: + attn_output = q.new_empty( + (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype) + else: + attn_output = output + temp_attn_output = q.new_empty( + (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype) + + # use fake cached_cu_seq_len for chunked loop + origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime + origin_kv_lens_runtime = attn_metadata.kv_lens_runtime + + for loop_idx in range(chunked_loop_num): + # {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim} zero padded + # fetch `loop_idx` chunk from kv cache + temp_cu_chunked_seq_len = attn_metadata.cu_chunked_seq_len[loop_idx] + total_ctx_chunked_tokens = attn_metadata.host_cu_chunked_seq_len[ + loop_idx, attn_metadata.num_contexts] + chunked_compressed_kv, chunked_k_pe = trtllm_attention.load_chunked_kv_cache_for_mla( + metadata=attn_metadata, + chunked_idx=loop_idx, + num_ctx_cached_tokens=total_ctx_chunked_tokens, + cu_chunked_seq_len=temp_cu_chunked_seq_len, + out_dtype=q.dtype) + + # up proj to uncompressed kv + # [tokens, 2, h, kv_dim], without rope_dim + chunked_kv = self.kv_b_proj(chunked_compressed_kv) + + # build full_kv + # full_kv {B, 2, chunk_size / tokens_per_block, h, tokens_per_block, kv_dim + rope_dim} + tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block + full_kv = torch.zeros([ + attn_metadata.num_contexts, 2, + (chunk_size + tokens_per_block - 1) // tokens_per_block, + self.num_heads, tokens_per_block, + max(self.qk_nope_head_dim + self.qk_rope_head_dim, + self.v_head_dim) + ], + dtype=q.dtype, + device=q.device) + mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla( + full_kv, + chunked_kv, + chunked_k_pe, + cu_chunked_seq_len=temp_cu_chunked_seq_len, + cached=True, + metadata=attn_metadata) + + # copy chunked_seq_len to replace kv_lens_runtime + attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[ + loop_idx] + attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[ + loop_idx] + out_scale = None + # do not apply mask for attention within loop + temp_attn_output = self.mha.forward( + q, + None, + None, + attn_metadata, + attention_input_type=AttentionInputType.context_only, + latent_cache=None, + out_scale=out_scale, + attention_mask=PredefinedAttentionMask.FULL, + mla_context_paged_kv=full_kv, + mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets, + softmax_stats_tensor=self.temp_softmax_stats_tensor, + output=temp_attn_output, + ) + # merge attn result + temp_merge_op = attn_metadata.merge_op_tensor[loop_idx] + 
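# merge_op per request: 0 = skip (no tokens from this chunk), 1 = fold this chunk's partial
+ # result into the accumulator using the softmax max/sum stats, 2 = first contribution,
+ # copy the partial result and its stats as-is. +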
trtllm_attention.merge_attention_for_mla( + attn_output, temp_attn_output, self.softmax_stats_tensor, + self.temp_softmax_stats_tensor, temp_merge_op, attn_metadata) + + # deal with the uncached kv + kv = self.kv_b_proj(compressed_kv) + _, k_pe = latent_cache.view([ + -1, self.kv_lora_rank + self.qk_rope_head_dim + ]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1) + k_pe = k_pe.contiguous() + # final round of attention + + # out_scale = getattr(self.o_proj, "inv_input_scale", None) + out_scale = None # Currently we use BF16 MHA for context phase + + tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block + full_kv = torch.zeros([ + attn_metadata.num_contexts, 2, + (attn_metadata.max_ctx_seq_len + tokens_per_block - 1) // + tokens_per_block, self.num_heads, tokens_per_block, + max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim) + ], + dtype=q.dtype, + device=q.device) + mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla( + full_kv, + kv, + k_pe, + cu_chunked_seq_len=None, + cached=False, + metadata=attn_metadata) + # copy q_lens to replace kv_lens_runtime + attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime + attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime + temp_attn_output = self.mha.forward( + q, + None, + None, + attn_metadata, + attention_input_type=AttentionInputType.context_only, + latent_cache=None, + out_scale=out_scale, + mla_context_paged_kv=full_kv, + mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets, + softmax_stats_tensor=self.temp_softmax_stats_tensor, + output=temp_attn_output, + ) + temp_merge_op = attn_metadata.merge_op_tensor[chunked_loop_num] + trtllm_attention.merge_attention_for_mla(attn_output, temp_attn_output, + self.softmax_stats_tensor, + self.temp_softmax_stats_tensor, + temp_merge_op, attn_metadata) + # copy back kv_lens_runtime and kv_lens_cuda_runtime + attn_metadata.kv_lens_runtime = origin_kv_lens_runtime + attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime + + return attn_output + def forward_context( self, q: torch.Tensor, @@ -934,7 +1094,11 @@ def forward_context( if isinstance(self.mha, TrtllmAttention): assert isinstance(attn_metadata, TrtllmAttentionMetadata) trtllm_attention = cast(TrtllmAttention, self.mha) - if trtllm_attention.has_cached_kv_for_mla_context(attn_metadata): + if trtllm_attention.is_chunked_prefill_for_mla_context( + attn_metadata): + return self.forward_context_with_chunked_prefill( + q, compressed_kv, latent_cache, attn_metadata, output) + elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata): return self.forward_context_with_cached_kv( q, latent_cache, attn_metadata, output) return self.forward_context_default(q, compressed_kv, k_pe, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e877117ec85..c58b4ca266e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -783,8 +783,9 @@ def disable_optimization(backend: Backend): def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager): enable_paged_context_mla = is_mla( - self.model.model_config.pretrained_config - ) and self.attn_runtime_features.cache_reuse + self.model.model_config.pretrained_config) and ( + self.attn_runtime_features.cache_reuse + or self.attn_runtime_features.chunked_prefill) if kv_cache_manager is None: return self.attn_backend.Metadata( max_num_requests=self.batch_size, diff --git 
a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 354981680ed..1f50fee7611 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -195,11 +195,14 @@ def create_py_executor( has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model() has_ngram_drafter = isinstance(spec_config, NGramConfig) + # chunk_unit_size may be changed to 64 when using flash mla attn_runtime_features = AttentionRuntimeFeatures( chunked_prefill=executor_config.enable_chunked_context, cache_reuse=executor_config.kv_cache_config.enable_block_reuse, has_speculative_draft_tokens=has_draft_model_engine or has_ngram_drafter, + chunk_unit_size=executor_config.tokens_per_block, + normal_chunk_size=executor_config.max_num_tokens, ) logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features) @@ -264,18 +267,6 @@ def create_py_executor( executor_config.max_num_tokens = model_engine.max_num_tokens spec_config = model_engine.spec_config - if executor_config.enable_chunked_context: - chunk_unit_size = executor_config.tokens_per_block - chunking_policy = ( - executor_config.scheduler_config.context_chunking_policy - if executor_config.scheduler_config.context_chunking_policy - is not None else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED) - assert chunk_unit_size is not None, "chunk_unit_size must be set" - ctx_chunk_config = ContextChunkingConfig(chunking_policy, - chunk_unit_size) - else: - ctx_chunk_config = None - config = model_engine.model.model_config.pretrained_config if is_mla(config): if model_engine.model.model_config.enable_flash_mla: @@ -301,7 +292,17 @@ def create_py_executor( ) executor_config.kv_cache_config.enable_block_reuse = False - executor_config.enable_chunked_context = False + if executor_config.enable_chunked_context: + chunk_unit_size = executor_config.tokens_per_block + chunking_policy = ( + executor_config.scheduler_config.context_chunking_policy + if executor_config.scheduler_config.context_chunking_policy + is not None else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED) + assert chunk_unit_size is not None, "chunk_unit_size must be set" + ctx_chunk_config = ContextChunkingConfig(chunking_policy, + chunk_unit_size) + else: + ctx_chunk_config = None with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER): sampler = instantiate_sampler(model_engine, executor_config, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 8f6e546c40a..f9bb7136d46 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1059,6 +1059,84 @@ def test_no_kv_cache_reuse(self, quant_dtype, mtp_nextn, fp8kv, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler", + [(False, False, False, False), + (False, True, False, False), + (False, False, True, False), + (False, False, False, True), + (False, True, True, True), (True, True, True, True)]) + @parametrize_with_ids("mtp_nextn", [0]) + @parametrize_with_ids("kv_cache_reuse", [True, False]) + @parametrize_with_ids( + "quant_dtype", + [ + pytest.param("none", marks=skip_pre_blackwell), + # pytest.param("fp8", marks=skip_pre_hopper), + # pytest.param("nvfp4", marks=skip_pre_blackwell) + ]) + # currently, chunked prefill is not supported for fp8 and nvfp4 + def test_chunked_prefill(self, 
quant_dtype, mtp_nextn, kv_cache_reuse, + fp8kv, attention_dp, cuda_graph, + overlap_scheduler): + if quant_dtype == "nvfp4" and mtp_nextn > 0: + pytest.skip("MTP is not supported for NVFP4") + if fp8kv: + pytest.skip("Currently do not support fp8") + + model_path = self.MODEL_PATH + if quant_dtype == "fp8": + model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/fp8" + elif quant_dtype == "nvfp4": + model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only" + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6, + enable_block_reuse=kv_cache_reuse) + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + use_cuda_graph=cuda_graph, + ) + mtp_config = None + if mtp_nextn > 0: + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + + if quant_dtype == "none": + assert not fp8kv + quant_config = None + else: + quant_config = QuantConfig() + if quant_dtype == "fp8": + quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + elif quant_dtype == "nvfp4": + quant_config.quant_algo = QuantAlgo.NVFP4 + if fp8kv: + quant_config.kv_cache_quant_algo = QuantAlgo.FP8 + pytorch_config["kv_cache_dtype"] = "fp8" + + llm = LLM(model_path, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=True, + max_num_tokens=512, + **pytorch_config, + quant_config=quant_config, + enable_attention_dp=attention_dp, + speculative_config=mtp_config) + + if quant_dtype == "fp8": + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES + elif quant_dtype == "nvfp4": + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + + if fp8kv: + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8 + + with llm: + # No need to run MMLU for fp8kv + if not fp8kv: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.timeout(7200) @pytest.mark.skip_less_device_memory(80000) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 9053617935e..0c0cc685069 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -49,6 +49,8 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=False-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
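The merge performed by merge_chunked_attention_for_mla (and by invokeMergeAttnWithSoftmax in the
unit test) is the standard log-sum-exp combination of partial softmax results. A minimal PyTorch
sketch of that math, illustrative only and using hypothetical names and shapes rather than the
op's actual signature:

    import torch

    def merge_partial_attention(o1, m1, s1, o2, m2, s2):
        # o1/o2: normalized partial attention outputs [tokens, heads, dim] over disjoint KV chunks
        # m1/m2: per-row softmax max, s1/s2: per-row softmax sum, both [tokens, heads]
        m = torch.maximum(m1, m2)
        w1 = s1 * torch.exp(m1 - m)
        w2 = s2 * torch.exp(m2 - m)
        s = w1 + w2
        o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / s.unsqueeze(-1)
        return o, m, s  # merged output plus the stats needed for the next merge

Folding chunks in one at a time with merge_op {0: skip, 1: merge, 2: copy} reproduces attention over
the full KV length, which is what the MlaChunkedPrefill tests above verify against the unchunked
reference.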