diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index b8855af568d..be7397182ec 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -559,7 +559,7 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi
         TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
             common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());
         setMaxAttentionWindowVec(newMaxAttentionWindowVec);
-        if (getMaxSequenceLen() < getMaxAttentionWindow())
+        if (getMaxSequenceLen() > getMaxAttentionWindow())
         {
             TLLM_LOG_WARNING("maxSequenceLen is reduced to maxAttentionWindow: %d", getMaxAttentionWindow());
             setMaxSequenceLen(getMaxAttentionWindow());
diff --git a/cpp/tests/batch_manager/trtGptModelTest.cpp b/cpp/tests/batch_manager/trtGptModelTest.cpp
index d45aedcabb1..1f1f5743725 100644
--- a/cpp/tests/batch_manager/trtGptModelTest.cpp
+++ b/cpp/tests/batch_manager/trtGptModelTest.cpp
@@ -870,6 +870,11 @@ class TrtGptModelIfbHelper : public TrtGptModelInflightBatching
    {
        return TrtGptModelInflightBatching::getKVCacheManager();
    }
+
+    [[nodiscard]] SizeType32 getMaxAttentionWindow() const
+    {
+        return TrtGptModelInflightBatching::getMaxAttentionWindow();
+    }
 };
 
 TEST_F(TrtGptModelTest, KVCacheReuseChunked)
@@ -1201,4 +1206,23 @@ TEST_F(LlamaModelLADTest, SeamlessLookaheadDecoding)
     }
 }
 
+TEST_F(TrtGptModelTest, ClampSeqLenToAttentionWindow)
+{
+    auto constexpr maxAttentionWindow = 65536;
+    auto constexpr maxSequenceLen = maxAttentionWindow + 1;
+
+    TrtGptModelOptionalParams optionalParams;
+    optionalParams.kvCacheConfig.maxAttentionWindowVec = std::vector<SizeType32>{maxAttentionWindow};
+    optionalParams.kvCacheConfig.freeGpuMemoryFraction = 0.0001; // minuscule amount of memory to force a clamp
+    optionalParams.maxBeamWidth = mBeamWidth;
+
+    auto modelConfig = mModelConfig;
+    modelConfig.setMaxSequenceLen(maxSequenceLen);
+
+    auto trtGptModel
+        = std::make_shared<TrtGptModelIfbHelper>(mLogger, modelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+    EXPECT_LT(trtGptModel->getMaxAttentionWindow(), maxAttentionWindow);
+    EXPECT_EQ(trtGptModel->getMaxSequenceLen(), trtGptModel->getMaxAttentionWindow());
+}
+
 } // namespace tensorrt_llm::batch_manager