Skip to content

Commit d718f78

Browse files
authored
Merge branch 'main' into dev-fxiong-drafter
2 parents c3792ef + 9518e14 commit d718f78

File tree

82 files changed

+703
-335
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

82 files changed

+703
-335
lines changed

.coderabbit.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
17+
language: "en-US"
18+
reviews:
19+
auto_review:
20+
drafts: true
21+
base_branches: ["main", "release/.+"]
22+
commit_status: false

.github/pull_request_template.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
@coderabbitai summary
12

2-
# PR title
3-
4-
Please write the PR title by following template:
3+
<!--
4+
Please write the PR title by following this template:
55
6-
[JIRA ticket link/nvbug link/github issue link][fix/feat/doc/infra/...] \<summary of this PR\>
6+
[JIRA ticket/NVBugs ID/GitHub issue][fix/feat/doc/infra/...] \<summary of this PR\>
77
8-
For example, assume I have a PR hope to support a new feature about cache manager of Jira TRTLLM-1000 ticket, it would be like
8+
For example, assume I have a PR to support a new feature about cache manager for JIRA ticket TRTLLM-1000, it would be like:
99
1010
[TRTLLM-1000][feat] Support a new feature about cache manager
1111
12+
Or I have a PR to fix a Llama3 accuracy issue:
13+
14+
[https://nvbugs/1234567][fix] Fix Llama3 accuracy issue
15+
-->
16+
1217
## Description
1318

19+
<!--
1420
Please explain the issue and the solution in short.
21+
-->
1522

1623
## Test Coverage
1724

benchmarks/cpp/disaggServerBenchmark.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,8 @@ class DisaggExecutorServer
636636
: texec::DecodingMode::Auto(),
637637
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices));
638638
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig);
639+
executorConfig.setCacheTransceiverConfig(
640+
texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT));
639641
constexpr int maxIterationsForRequestStats = 1000;
640642
if (mEnableCollectKvCacheTransferTime)
641643
{

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,28 +70,20 @@ class BaseCacheTransceiver
7070
class CacheTransceiver : public BaseCacheTransceiver
7171
{
7272
public:
73-
enum class CommType : std::uint8_t
74-
{
75-
UNKNOWN = 0,
76-
MPI = 1,
77-
UCX = 2,
78-
NIXL = 3
79-
};
80-
81-
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
73+
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
8274
executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
8375
nvinfer1::DataType dataType,
8476
executor::kv_cache::CacheState::AttentionType attentionType
8577
= executor::kv_cache::CacheState::AttentionType::kDEFAULT,
8678
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
8779

88-
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
89-
std::vector<SizeType32> numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
90-
runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType,
80+
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector<SizeType32> numKvHeadsPerLayer,
81+
SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig,
82+
nvinfer1::DataType dataType,
9183
executor::kv_cache::CacheState::AttentionType attentionType
9284
= executor::kv_cache::CacheState::AttentionType::kDEFAULT,
9385
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt)
94-
: CacheTransceiver(cacheManager, commType,
86+
: CacheTransceiver(cacheManager,
9587
executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig,
9688
dataType, attentionType, cacheTransceiverConfig)
9789
{
@@ -118,7 +110,6 @@ class CacheTransceiver : public BaseCacheTransceiver
118110

119111
void setContextState(LlmRequest* llmRequest);
120112

121-
CommType mCommType;
122113
std::unique_ptr<DataResponder> mDataResponder;
123114
std::unique_ptr<DataRequester> mDataRequester;
124115
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,18 +1430,29 @@ class LogitsPostProcessorConfig
14301430
class CacheTransceiverConfig
14311431
{
14321432
public:
1433-
explicit CacheTransceiverConfig(std::optional<size_t> maxNumTokens = std::nullopt);
1433+
enum class BackendType : std::uint8_t
1434+
{
1435+
DEFAULT = 0,
1436+
MPI = 1,
1437+
UCX = 2,
1438+
NIXL = 3
1439+
};
1440+
explicit CacheTransceiverConfig(
1441+
std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
14341442

14351443
bool operator==(CacheTransceiverConfig const& other) const;
1444+
void setBackendType(std::optional<BackendType> backendType);
1445+
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
14361446

1437-
[[nodiscard]] std::optional<size_t> getMaxNumTokens() const;
1438-
void setMaxNumTokens(size_t maxNumTokens);
1447+
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
1448+
[[nodiscard]] std::optional<BackendType> getBackendType() const;
14391449

14401450
private:
1451+
std::optional<BackendType> mBackendType;
14411452
/// @brief The maximum number of tokens that the CacheTransceiver's pre-allocated buffer can hold. If the number of
14421453
/// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
14431454
/// transfer may be degraded.
1444-
std::optional<size_t> mMaxNumTokens;
1455+
std::optional<size_t> mMaxTokensInBuffer;
14451456
};
14461457

14471458
/// @brief Configuration class for the model executor

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ CacheTransBufferManager::CacheTransBufferManager(
210210
{
211211
auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
212212
auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
213-
auto validTokenNum = windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value();
213+
auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value());
214214
bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
215215
}
216216
}
@@ -230,26 +230,37 @@ CacheTransBufferManager::CacheTransBufferManager(
230230
TLLM_LOG_INFO(
231231
"CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
232232
"mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
233-
"mUseFabricMemory:%d",
233+
"mUseFabricMemory:%d mDataType:%d",
234234
maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
235-
mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory);
236-
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
235+
mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);
237236

238-
TLLM_CHECK_WITH_INFO(to_allocate, "CacheTransBufferManager: to_allocate is false");
239237
allocateBuffer();
240238
}
241239

242-
size_t CacheTransBufferManager::preAllocBufferSize(std::optional<size_t> maxNumTokens)
240+
size_t CacheTransBufferManager::preAllocBufferSize(
241+
std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
242+
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
243243
{
244-
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
245-
if (!to_allocate)
244+
if (!cacheTransceiverConfig.has_value())
246245
{
247246
return 0;
248247
}
248+
if (!cacheTransceiverConfig->getBackendType().has_value())
249+
{
250+
return 0;
251+
}
252+
auto maxNumTokens = cacheTransceiverConfig->getMaxTokensInBuffer();
249253
size_t TransferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer();
250254
if (maxNumTokens.has_value())
251255
{
252-
TransferBufferSize = maxNumTokens.value();
256+
TransferBufferSize = 0;
257+
for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow)
258+
{
259+
auto validTokenNum
260+
= (static_cast<size_t>(windowSize) < maxNumTokens.value() ? static_cast<size_t>(windowSize)
261+
: maxNumTokens.value());
262+
TransferBufferSize += validTokenNum * cacheSizeBytesPerToken;
263+
}
253264
}
254265
bool useFabricMemory = FabricMemory::supportFbaricMemory()
255266
&& (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()));
@@ -329,6 +340,14 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
329340
size_t bufferCoverTargetNum = std::min(
330341
static_cast<size_t>(targetNum), mTransferBufferSize / (targetBufferEleSize * common::getDTypeSize(mDataType)));
331342
TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
343+
if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
344+
{
345+
TLLM_LOG_WARNING(
346+
"CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic buffer, "
347+
"it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may "
348+
"be degraded",
349+
bufferCoverTargetNum, targetNum);
350+
}
332351
if (bufferId.has_value())
333352
{
334353
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#pragma once
1919

2020
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
21+
#include "tensorrt_llm/executor/executor.h"
2122
#include "tensorrt_llm/runtime/bufferManager.h"
2223
#include "tensorrt_llm/runtime/iTensor.h"
2324
#include <atomic>
@@ -59,7 +60,8 @@ class CacheTransBufferManager
5960
CacheTransBufferManager(
6061
KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);
6162

62-
static size_t preAllocBufferSize(std::optional<size_t> maxNumTokens = std::nullopt);
63+
static size_t preAllocBufferSize(std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
64+
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig = std::nullopt);
6365

6466
std::optional<int> assignBufferIndexForSend();
6567
void freeBufferIndexForSend(std::optional<int> bufferId);

0 commit comments

Comments
 (0)