Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,16 @@ class GenerationRequest
return mKvCacheRetentionConfig.getDecodeDurationMs();
}

[[nodiscard]] executor::KvCacheTransferMode getTransferMode() const
{
return mKvCacheRetentionConfig.getTransferMode();
}

[[nodiscard]] std::optional<std::string> const& getDirectory() const
{
return mKvCacheRetentionConfig.getDirectory();
}

// @brief Check whether the sequence uses cyclic KV cache.
// @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
// @details If `true`, we cannot store the sequence's KV cache for reuse.
Expand Down Expand Up @@ -691,11 +701,14 @@ class WindowBlockManager

//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if block is already in primary memory.
void onboardBlock(BlockPtr const& offloadBlock);
void onboardBlock(BlockPtr const& offloadBlock,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

//! \brief Bring block from primary to secondary memory.
//! \details Does nothing if block is already in secondary memory.
void offloadBlock(BlockPtr const& block);
void offloadBlock(BlockPtr const& block, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

//! \brief Find first new block that must be allocated for context phase and return it's concatenated token vectors.
//! \details Only full blocks are considered.
Expand Down Expand Up @@ -749,7 +762,9 @@ class WindowBlockManager
//! \param sequence Sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType32 loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions);
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

//! \brief Free block and all it's descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block, executor::RetentionPriority priority,
Expand All @@ -758,7 +773,9 @@ class WindowBlockManager
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock(
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
std::optional<std::chrono::milliseconds> durationMs = std::nullopt,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(BlockPtr const& block, std::optional<executor::RetentionPriority> priority = std::nullopt,
Expand Down Expand Up @@ -894,11 +911,15 @@ class BlockManager

//! \brief Bring block from primary to secondary memory for window size.
//! \details Does nothing if block is already in primary memory.
void onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize);
void onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

//! \brief Bring block from primary to secondary memory for window size.
//! \details Does nothing if block is already in secondary memory.
void offloadBlock(BlockPtr const& block, SizeType32 windowSize);
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::optional<std::string> directory = std::nullopt);

void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize)
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ class KvCacheRetentionConfig
[[nodiscard]] RetentionPriority getDecodeRetentionPriority() const;
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const;
[[nodiscard]] KvCacheTransferMode getTransferMode() const;
[[nodiscard]] std::optional<std::string> getDirectory() const;
[[nodiscard]] std::optional<std::string> const& getDirectory() const;

/// @brief Convert the token range data into an entry per kv block. Returns a tuple of vectors corresponding to the
/// priorities and durations for each block.
Expand Down
59 changes: 37 additions & 22 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,8 +843,9 @@ void WindowBlockManager::freeChildren(
claimLeafBlock(block, priority, durationMs);
}

BlockPtr WindowBlockManager::getFreeBlock(
executor::RetentionPriority priority, std::optional<std::chrono::milliseconds> durationMs)
BlockPtr WindowBlockManager::getFreeBlock(executor::RetentionPriority priority,
std::optional<std::chrono::milliseconds> durationMs, executor::KvCacheTransferMode mode,
std::optional<std::string> directory)
{
// eviction policy get free primary block
auto [block, canOffload] = mEvictionPolicy->getFreeBlock(kPrimaryLevel);
Expand All @@ -865,7 +866,7 @@ BlockPtr WindowBlockManager::getFreeBlock(
mEvictionPolicy->claimBlock(block);
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
mTransferManager->offload(block, offloadBlock, mPools);
mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block)
block->swapMemoryPoolBlockOffset(offloadBlock);

Expand Down Expand Up @@ -917,17 +918,20 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}

void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize)
void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize, executor::KvCacheTransferMode mode,
std::optional<std::string> directory)
{
mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock);
mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock, mode, directory);
}

void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
void WindowBlockManager::onboardBlock(
BlockPtr const& offloadBlock, executor::KvCacheTransferMode mode, std::optional<std::string> directory)
{
if (mOnboardBlocks && !offloadBlock->isPrimary())
{
auto block = getFreeBlock();
mTransferManager->onboard(offloadBlock, block, mPools);
auto block
= getFreeBlock(executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt, mode, directory);
mTransferManager->onboard(offloadBlock, block, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block and vice versa)
offloadBlock->swapMemoryPoolBlockOffset(block);

Expand All @@ -942,20 +946,22 @@ void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
}
}

void BlockManager::offloadBlock(BlockPtr const& block, SizeType32 windowSize)
void BlockManager::offloadBlock(BlockPtr const& block, SizeType32 windowSize, executor::KvCacheTransferMode mode,
std::optional<std::string> directory)
{
mWindowBlockManagers.at(windowSize).offloadBlock(block);
mWindowBlockManagers.at(windowSize).offloadBlock(block, mode, directory);
}

void WindowBlockManager::offloadBlock(BlockPtr const& block)
void WindowBlockManager::offloadBlock(
BlockPtr const& block, executor::KvCacheTransferMode mode, std::optional<std::string> directory)
{
if (mOnboardBlocks && block->isPrimary())
{
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
// If we're swapping a block to secondary memory, maintain the prior priority values.
mEvictionPolicy->claimBlock(offloadBlock);
mTransferManager->offload(block, offloadBlock, mPools);
mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block)
block->swapMemoryPoolBlockOffset(offloadBlock);

Expand Down Expand Up @@ -1009,7 +1015,8 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block)
}

SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions)
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions,
executor::KvCacheTransferMode mode, std::optional<std::string> directory)
{
SizeType32 numMatchedTokens{0};
auto searchRoot = mCachedBlocksRoot;
Expand Down Expand Up @@ -1043,8 +1050,9 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
if (matchingBlock->hasRefs() || !matchingBlock->isLeaf())
{
// Somebody else is using block or it is not a leaf, copy reusable tokens
auto newBlock = getFreeBlock(matchingBlock->getPriority(), matchingBlock->getDurationMs());
mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched);
auto newBlock
= getFreeBlock(matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory);
mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory);
// TODO: (optional) Send out event
matchingBlock = newBlock;
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(),
Expand All @@ -1068,7 +1076,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId);
searchRoot = matchingBlock;
}
onboardBlock(matchingBlock);
onboardBlock(matchingBlock, mode, directory);
addBlockToAllBeams(matchingBlock, sequence);
// TODO: only add once for reused blocks
++mReusedBlocks;
Expand All @@ -1084,7 +1092,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
// If we haven't set a priority, set it to the default priority level (low)
auto freeBlock = getFreeBlock(perBlockRetentions[bi].retentionPriority.value_or(
executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
perBlockRetentions[bi].durationMs);
perBlockRetentions[bi].durationMs, mode, directory);
addBlockToAllBeams(freeBlock, sequence);
TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match, allocated new block %d for sequence %lu",
mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId());
Expand All @@ -1110,7 +1118,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
// If we haven't set a priority, set it to the default priority level (low)
auto freeBlock = getFreeBlock(perBlockRetentions[bi].retentionPriority.value_or(
executor::KvCacheRetentionConfig::kDefaultRetentionPriority),
perBlockRetentions[bi].durationMs);
perBlockRetentions[bi].durationMs, mode, directory);
addBlockToBeam(freeBlock, sequence, beamIdx);
if (blockItr != blockKeys.end())
{
Expand Down Expand Up @@ -1179,9 +1187,13 @@ void WindowBlockManager::addSequence(
auto perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig())
.getPerBlockRetentionPriorityDuration(getTokensPerBlock(), inputLength);

auto mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode();
auto directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory();

TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks);

auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
auto const prepopulatedPromptLen
= loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory);
mReusedTokens += static_cast<double>(prepopulatedPromptLen);
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
Expand Down Expand Up @@ -1250,7 +1262,8 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
if (shareAmongBeams)
{
// add same block to all beams
auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs());
auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(),
sequence.getTransferMode(), sequence.getDirectory());
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
addBlockToBeam(block, sequence, beamIdx);
Expand All @@ -1261,7 +1274,8 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
// add different block to each beam
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs());
auto block = getFreeBlock(sequence.getDecodeRetentionPriority(), sequence.getDecodeDurationMs(),
sequence.getTransferMode(), sequence.getDirectory());
addBlockToBeam(block, sequence, beamIdx);
}
}
Expand Down Expand Up @@ -1362,7 +1376,8 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp
TLLM_CHECK_WITH_INFO(hasFreeBlocks(beamWidth), "Can't allocate new blocks. No free blocks left.");
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = getFreeBlock();
auto block = getFreeBlock(executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt,
sequence.getTransferMode(), sequence.getDirectory());
block->incRefCount();
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)
{
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/executor/kvCacheRetentionConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ KvCacheTransferMode KvCacheRetentionConfig::getTransferMode() const
return mTransferMode;
}

std::optional<std::string> KvCacheRetentionConfig::getDirectory() const
std::optional<std::string> const& KvCacheRetentionConfig::getDirectory() const
{
return mDirectory;
}
Expand Down