Commit e05ae76

Merge branch 'main' into user/pengyunl/disagg_check

2 parents d6dd04d + ac23f4a

79 files changed: +2516 −754 lines

.github/pull_request_template.md
Lines changed: 12 additions & 6 deletions

@@ -3,15 +3,21 @@
 <!--
 Please write the PR title by following this template:
 
-[JIRA ticket/NVBugs ID/GitHub issue][fix/feat/doc/infra/...] \<summary of this PR\>
+**[JIRA ticket/NVBugs ID/GitHub issue/None][type] Summary**
 
-For example, assume I have a PR to support a new feature about cache manager for JIRA ticket TRTLLM-1000, it would be like:
+Valid ticket formats:
+- JIRA ticket: [TRTLLM-1234] or [FOOBAR-123] for other FOOBAR project
+- NVBugs ID: [https://nvbugs/1234567]
+- GitHub issue: [#1234]
+- No ticket: [None]
 
-[TRTLLM-1000][feat] Support a new feature about cache manager
+Valid types (lowercase): [fix], [feat], [doc], [infra], [chore], etc.
 
-Or I have a PR to fix a Llama3 accuracy issue:
-
-[https://nvbugs/1234567][fix] Fix Llama3 accuracy issue
+Examples:
+- [TRTLLM-1234][feat] Add new feature
+- [https://nvbugs/1234567][fix] Fix some bugs
+- [#1234][doc] Update documentation
+- [None][chore] Minor clean-up
 
 -->
 
 ## Description

.github/workflows/pr-check.yml
Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: PR Checks
+
+on:
+  pull_request:
+    types: [opened, edited, synchronize, reopened]
+
+jobs:
+  check-pr-title:
+    name: Check PR Title Format
+    runs-on: ubuntu-latest
+    steps:
+      - name: Validate PR Title Format
+        id: check-pr-title
+        uses: agenthunt/[email protected]
+        continue-on-error: true
+        with:
+          pr-title-regex: "^(\\[(None|[A-Z0-9]+-[0-9]+|#[0-9]+|https:\\/\\/nvbugs\\/[0-9]+)\\])(\\[[a-z0-9]+\\]) (([^ ].*)?[^ ])$"
+          pr-body-regex: ""
+
+      - name: PR Title Format Guide
+        if: steps.check-pr-title.outcome == 'failure'
+        run: |
+          echo "::error::PR title format check failed."
+          echo "Expected PR title format:"
+          echo " [JIRA ticket/NVBugs ID/GitHub issue/None][type] Summary"
+          echo ""
+          echo "Valid ticket formats:"
+          echo " - JIRA ticket: [TRTLLM-1234] or [FOOBAR-123] for other FOOBAR project"
+          echo " - NVBugs ID: [https://nvbugs/1234567]"
+          echo " - GitHub issue: [#1234]"
+          echo " - No ticket: [None]"
+          echo ""
+          echo "Valid types (lowercase): [fix], [feat], [doc], [infra], [chore], etc."
+          echo ""
+          echo "Examples:"
+          echo " - [TRTLLM-1234][feat] Add new feature"
+          echo " - [https://nvbugs/1234567][fix] Fix some bugs"
+          echo " - [#1234][doc] Update documentation"
+          echo " - [None][chore] Minor clean-up"
+          exit 1
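
To try the title check locally, here is a minimal sketch (not part of this commit; the sample titles are illustrative) that exercises the same pattern with std::regex. The YAML double-escaping is removed in the raw string; the first four titles should pass, the last two should fail.

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main()
{
    // Same pattern as pr-title-regex above, with the YAML escaping removed.
    std::regex const titlePattern(
        R"(^(\[(None|[A-Z0-9]+-[0-9]+|#[0-9]+|https://nvbugs/[0-9]+)\])(\[[a-z0-9]+\]) (([^ ].*)?[^ ])$)");

    std::vector<std::string> const titles{
        "[TRTLLM-1234][feat] Add new feature",
        "[https://nvbugs/1234567][fix] Fix some bugs",
        "[#1234][doc] Update documentation",
        "[None][chore] Minor clean-up",
        "[TRTLLM-1234][Feat] Type tag must be lowercase",
        "Add new feature without any ticket tag",
    };

    for (auto const& title : titles)
    {
        std::cout << (std::regex_match(title, titlePattern) ? "PASS" : "FAIL") << "  " << title << '\n';
    }
    return 0;
}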

cpp/include/tensorrt_llm/runtime/gptDecoder.h
Lines changed: 0 additions & 1 deletion

@@ -20,7 +20,6 @@
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/decodingInput.h"
 #include "tensorrt_llm/runtime/decodingOutput.h"
-#include "tensorrt_llm/runtime/request.h"
 #include "tensorrt_llm/runtime/samplingConfig.h"
 
 #include <NvInferRuntime.h>

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Lines changed: 2 additions & 2 deletions

@@ -477,9 +477,9 @@ std::map<SizeType32, float> BlockManager::calculateWindowSizeToShare(
         windowSizeToContribution[windowSize] = cacheSizeWeight;
     }
 
-    for (auto const& [windowSize, layers] : windowSizeToLayers)
+    for (auto const& [windowSize, _] : windowSizeToLayers)
     {
-        windowSizeToContribution.at(windowSize) *= windowSize * layers.size();
+        windowSizeToContribution.at(windowSize) *= windowSize;
     }
     auto const windowSizesTotalSum = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(),
         0.0, [](auto sum, auto const& windowSize) { return sum + windowSize.second; });
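
To make the effect of this hunk concrete, here is a small self-contained sketch (hypothetical free function; the final normalization by the total is assumed from the std::accumulate above): each window size's weight is now scaled by the window size alone, where it was previously scaled by windowSize * layers.size().

#include <cstdint>
#include <map>
#include <numeric>
#include <vector>

using SizeType32 = std::int32_t;

// Hypothetical helper mirroring the hunk above: scale each window's
// cache-size weight by its window size, then normalize to fractions.
std::map<SizeType32, float> windowSizeToShare(
    std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers,
    std::map<SizeType32, float> windowSizeToContribution)
{
    for (auto const& [windowSize, _] : windowSizeToLayers)
    {
        // Before this commit the factor was windowSize * layers.size().
        windowSizeToContribution.at(windowSize) *= windowSize;
    }
    auto const total = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(), 0.0,
        [](auto sum, auto const& kv) { return sum + kv.second; });
    for (auto& [windowSize, share] : windowSizeToContribution)
    {
        share = static_cast<float>(share / total); // fraction of capacity for this window size
    }
    return windowSizeToContribution;
}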

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp
Lines changed: 2 additions & 1 deletion

@@ -55,7 +55,8 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
     // precompiled XQA does not use is_fp8_output as hashing key
     return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
         xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
-        xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false};
+        xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false,
+        isXqaJit ? std::optional(xqaParams.position_embedding_type) : std::nullopt};
 }
 
 } // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
Lines changed: 4 additions & 1 deletion

@@ -67,14 +67,15 @@ struct XQAKernelRuntimeHashKey
     bool paged_kv_cache;
     bool multi_query_tokens;
     bool is_fp8_output;
+    std::optional<PositionEmbeddingType> position_embedding_type;
 
     bool operator==(XQAKernelRuntimeHashKey const& other) const
     {
         return kv_data_type == other.kv_data_type && head_size == other.head_size
             && num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
             && multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
             && tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache
-            && is_fp8_output == other.is_fp8_output;
+            && is_fp8_output == other.is_fp8_output && position_embedding_type == other.position_embedding_type;
     }
 };
 
@@ -103,6 +104,8 @@ struct XQAKernelRuntimeHasher
         key ^= s.multi_query_tokens;
         key <<= 1;
         key ^= s.is_fp8_output;
+        key <<= 8;
+        key ^= static_cast<int8_t>(s.position_embedding_type.value_or(static_cast<PositionEmbeddingType>(-1)));
         return key;
     }
 };
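
The two new hasher lines shift the key by a whole byte before XOR-folding the enum, so the position-embedding value cannot collide with the single-bit bool flags folded in above it. A standalone sketch of just that step (the enumerators here are placeholders, not the real PositionEmbeddingType list):

#include <cstdint>
#include <optional>

enum class PositionEmbeddingType : int8_t
{
    kLEARNED = 0, // placeholder enumerators for illustration only
    kROPE = 1,
};

uint64_t foldIntoKey(uint64_t key, std::optional<PositionEmbeddingType> positionEmbeddingType)
{
    key <<= 8; // reserve a full byte so the enum does not overlap the bool flags
    // -1 is the sentinel for "no position embedding type" (the precompiled path);
    // the int8_t sign-extends through integer promotion before the XOR.
    key ^= static_cast<int8_t>(
        positionEmbeddingType.value_or(static_cast<PositionEmbeddingType>(-1)));
    return key;
}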

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
Lines changed: 2 additions & 2 deletions

@@ -37,8 +37,8 @@ using ::tensorrt_llm::kernels::XQAKernelMetaInfo;
 XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const& kernelMeta)
 {
     return {kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth, kernelMeta.mNumQHeadsOverKV,
-        kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens,
-        0 /* xqa jit param is_fp8_output */};
+        kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens, false,
+        std::nullopt};
 }
 
 } // anonymous namespace

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
Lines changed: 3 additions & 2 deletions

@@ -97,7 +97,7 @@ class XQAKernelList
         }
         XQAKernelRuntimeHashKey hash_key{kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth,
             kernelMeta.mNumQHeadsOverKV, kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache,
-            kernelMeta.mMultiQueryTokens, 0 /* xqa jit param is_fp8_output */};
+            kernelMeta.mMultiQueryTokens, false, std::nullopt};
 
         mFunctions.insert(std::make_pair(hash_key, funcInfo));
     }
@@ -128,7 +128,8 @@ class XQAKernelList
         XQAKernelRuntimeHashKey hash_key
             = {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
                 xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
-                xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0 /* xqa jit param is_fp8_output */};
+                xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0, /* xqa jit param is_fp8_output */
+                std::nullopt};
         auto const findIter = mFunctions.find(hash_key);
         return findIter != mFunctions.end();
     }
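
The key and hasher are consumed by a hash map of kernels (mFunctions above). A condensed sketch with a hypothetical two-field key shows the lookup pattern, and why the precompiled path consistently passes false and std::nullopt so its keys keep matching:

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>

// Condensed two-field stand-in for XQAKernelRuntimeHashKey.
struct MiniKey
{
    bool is_fp8_output;
    std::optional<int8_t> position_embedding_type;

    bool operator==(MiniKey const& other) const
    {
        return is_fp8_output == other.is_fp8_output
            && position_embedding_type == other.position_embedding_type;
    }
};

struct MiniKeyHasher
{
    size_t operator()(MiniKey const& k) const
    {
        uint64_t key = k.is_fp8_output;
        key <<= 8; // the real hasher XORs the sign-extended int8_t here
        key ^= static_cast<uint8_t>(k.position_embedding_type.value_or(-1));
        return key;
    }
};

int main()
{
    std::unordered_map<MiniKey, std::string, MiniKeyHasher> functions;
    functions[{false, std::nullopt}] = "precompiled kernel";
    functions[{true, 0}] = "jit kernel";

    // A precompiled lookup always uses {false, std::nullopt}, matching the hunks above.
    return functions.find({false, std::nullopt}) != functions.end() ? 0 : 1;
}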

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
Lines changed: 3 additions & 56 deletions

@@ -20,18 +20,13 @@
 #include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
 #include "tensorrt_llm/batch_manager/capacityScheduler.h"
 #include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
-#include "tensorrt_llm/batch_manager/handleContextLogits.h"
-#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
-#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
 #include "tensorrt_llm/batch_manager/medusaBuffers.h"
 #include "tensorrt_llm/batch_manager/microBatchScheduler.h"
 #include "tensorrt_llm/batch_manager/pauseRequests.h"
 #include "tensorrt_llm/batch_manager/peftCacheManager.h"
-#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
-#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
 #include "tensorrt_llm/nanobind/common/customCasters.h"
 #include "tensorrt_llm/runtime/decoderState.h"
 #include "tensorrt_llm/runtime/torch.h"
@@ -94,48 +89,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
             nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt)
         .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
 
-    nb::class_<HandleContextLogits>(m, HandleContextLogits::name)
-        .def(nb::init<>())
-        .def(
-            "__call__",
-            [](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
-                at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
-                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
-            {
-                return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
-                    manager, medusaBuffers);
-            },
-            nb::arg("decoder_input_buffers"), nb::arg("context_requests"), nb::arg("logits"),
-            nb::arg("num_context_logits"), nb::arg("model_config"), nb::arg("buffer_manager"),
-            nb::arg("medusa_buffers") = std::nullopt)
-        .def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
-
-    nb::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
-        .def(nb::init<>())
-        .def(
-            "__call__",
-            [](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
-                RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
-                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
-                OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
-            {
-                self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
-                    genRuntimeBuffers, medusaBuffers);
-            },
-            nb::arg("decoder_input_buffers"), nb::arg("generation_requests"), nb::arg("logits"),
-            nb::arg("logits_index"), nb::arg("model_config"), nb::arg("buffer_manager"),
-            nb::arg("gen_runtime_buffers") = std::nullopt, nb::arg("medusa_buffers") = std::nullopt)
-        .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
-
-    nb::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
-        .def(nb::init<>())
-        .def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("decoder_input_buffers"),
-            nb::arg("decoder_state"), nb::arg("model_config"), nb::arg("max_num_sequences"),
-            nb::arg("fused_runtime_buffers") = std::nullopt)
-        .def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
-
     nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
         .def(nb::init<>())
         .def("__call__", &LogitsPostProcessor::operator(), nb::arg("decoder_input_buffers"),
@@ -154,8 +107,9 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
                 DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
                 tensorrt_llm::runtime::CudaStream const& runtimeStream,
                 tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
-                SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
+                SizeType32 beamWidth)
             {
+                OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
                 auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
                     worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
                     runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -166,13 +120,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
             nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
             nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
             nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
-            nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt)
+            nb::arg("max_sequence_length"), nb::arg("beam_width"))
         .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
-
-    nb::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
-        .def(nb::init<>())
-        .def("__call__", &UpdateDecoderBuffers::operator(), nb::arg("model_config"), nb::arg("decoder_output_buffers"),
-            nb::arg("copy_buffer_manager"), nb::arg("decoder_state"), nb::arg("return_log_probs"),
-            nb::arg("decoder_finish_event"))
-        .def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
 }
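
For readers skimming the removals, the bindings that remain all follow one nanobind idiom: a default-constructible algorithm object exposed as a Python callable plus a name accessor. A minimal sketch with a hypothetical Echo algorithm and module name:

#include <nanobind/nanobind.h>
#include <string>

namespace nb = nanobind;

// Hypothetical stand-in for an algorithm class such as LogitsPostProcessor.
struct Echo
{
    constexpr static auto name = "Echo";

    std::string operator()(std::string const& text) const
    {
        return text;
    }
};

NB_MODULE(example_bindings, m)
{
    nb::class_<Echo>(m, Echo::name)
        .def(nb::init<>())
        .def("__call__", &Echo::operator(), nb::arg("text"))
        .def("name", [](Echo const&) { return Echo::name; });
}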

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Lines changed: 0 additions & 9 deletions

@@ -20,11 +20,9 @@
 
 #include "tensorrt_llm/batch_manager/common.h"
 #include "tensorrt_llm/batch_manager/decoderBuffers.h"
-#include "tensorrt_llm/batch_manager/medusaBuffers.h"
 #include "tensorrt_llm/batch_manager/microBatchScheduler.h"
 #include "tensorrt_llm/batch_manager/peftCacheManager.h"
 #include "tensorrt_llm/batch_manager/rnnStateManager.h"
-#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
 #include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
 #include "tensorrt_llm/nanobind/common/bindTypes.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
@@ -419,13 +417,6 @@ void initBindings(nb::module_& m)
         .def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
         .def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
 
-    nb::class_<tb::MedusaBuffers>(m, "MedusaBuffers")
-        .def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&,
-                 runtime::ModelConfig const&, runtime::WorldConfig const&, executor::DecodingConfig const&,
-                 runtime::TllmRuntime const&>(),
-            nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"), nb::arg("model_config"),
-            nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("runtime"));
-
     m.def(
         "add_new_tokens_to_requests",
         [](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
