9 changes: 8 additions & 1 deletion cpp/include/tensorrt_llm/executor/executor.h
@@ -1484,7 +1484,8 @@ class ExecutorConfig
std::optional<GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
std::optional<CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt,
bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false);
bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false,
bool failFastOnAttentionWindowTooLarge = false);

[[nodiscard]] SizeType32 getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -1519,6 +1520,7 @@ class ExecutorConfig
[[nodiscard]] bool getPromptTableOffloading() const;
[[nodiscard]] std::optional<CacheTransceiverConfig> getCacheTransceiverConfig() const;
[[nodiscard]] bool getEnableTrtOverlap() const;
[[nodiscard]] bool getFailFastOnAttentionWindowTooLarge() const;

void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -1548,6 +1550,7 @@ class ExecutorConfig
void setPromptTableOffloading(bool promptTableOffloading);
void setCacheTransceiverConfig(CacheTransceiverConfig const& cacheTransceiverConfig);
void setEnableTrtOverlap(bool enableTrtOverlap);
void setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge);

private:
friend class Serialization;
@@ -1634,6 +1637,10 @@ class ExecutorConfig

/// @brief Controls whether preparation and TRT engine execution should be overlapped.
bool mEnableTrtOverlap{false};

/// @brief Controls whether to fail fast when attention window is too large to fit even a single sequence in the KV
/// cache.
bool mFailFastOnAttentionWindowTooLarge{false};
};

struct KVCacheCreatedData
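For context on how a client would use the new option: the flag defaults to false and can be toggled either through the new trailing constructor argument or the setter declared above. A minimal sketch, assuming only the default-constructible ExecutorConfig and the include path shown in this PR:

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

int main()
{
    tle::ExecutorConfig config;                        // fail-fast is disabled by default
    config.setFailFastOnAttentionWindowTooLarge(true); // opt in: error out instead of clamping the window
    return config.getFailFastOnAttentionWindowTooLarge() ? 0 : 1;
}
```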
32 changes: 22 additions & 10 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -296,27 +296,27 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer

auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
= BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig);

if (mModelConfig.useCrossAttention())
{
TLLM_CHECK_WITH_INFO(kvCacheConfig.getCrossKvCacheFraction().has_value(),
"Must set crossKvCacheFraction for encoder-decoder model");
auto const crossKvCacheFraction = kvCacheConfig.getCrossKvCacheFraction().value();
mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF,
freePrimaryMemBytes * (1.0f - crossKvCacheFraction),
freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize);
mCrossKvCacheManager
= createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS, freePrimaryMemBytes * crossKvCacheFraction,
freeSecondaryMemBytes * crossKvCacheFraction, cacheTransPreAllocaSize);
freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize,
executorConfig.getFailFastOnAttentionWindowTooLarge());
mCrossKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS,
freePrimaryMemBytes * crossKvCacheFraction, freeSecondaryMemBytes * crossKvCacheFraction,
cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
TLLM_LOG_INFO("This is an Encoder-Decoder model, set %0.1f cross KV cache fraction based on the config.",
crossKvCacheFraction);
}
else
{
TLLM_CHECK_WITH_INFO(!kvCacheConfig.getCrossKvCacheFraction().has_value(),
"Do not set crossKvCacheFraction for decoder-only model");
mKvCacheManager = createKvCacheManager(
kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes, freeSecondaryMemBytes, cacheTransPreAllocaSize);
mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes,
freeSecondaryMemBytes, cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
}

mCacheTransceiver
@@ -550,7 +550,8 @@ void TrtGptModelInflightBatching::reshapeKvTensors(OffsetTableDimensions const&
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;

std::pair<BlocksPerWindow, std::vector<SizeType32>>
TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWindow const& blocksPerWindow)
TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(
BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge)
{
// At this point, we can only validate that the cheapest sequence in terms of kv-cache resources still fits. More
// validation is needed on a per-request basis, once the prompt / output lengths and the actual beam width are
@@ -591,6 +592,16 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi
}
TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());

if (failFastOnAttentionWindowTooLarge)
{
throw std::runtime_error(
"Attention window too large to fit even a single sequence in the KV cache. Failing fast rather than "
"attempting an adjustment of the window sizes. "
"Old: "
+ common::vec2str(getMaxAttentionWindowVec()) + ", New: " + common::vec2str(newMaxAttentionWindowVec));
}

setMaxAttentionWindowVec(newMaxAttentionWindowVec);
if (getMaxSequenceLen() > getMaxAttentionWindow())
{
@@ -613,7 +624,7 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi

std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::createKvCacheManager(
KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType, uint64_t freePrimaryMemBytes,
uint64_t freeSecondaryMemBytes, size_t extraCostMemory)
uint64_t freeSecondaryMemBytes, size_t extraCostMemory, bool const failFastOnAttentionWindowTooLarge)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
bool isCrossAttention = kvCacheType == KvCacheType::kCROSS;
Expand Down Expand Up @@ -657,7 +668,8 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c
// and user also didn't provide maxAttentionWindow, which leads it to be equal to maxSeqLen
if (kvCacheType == KvCacheType::kSELF)
{
std::tie(blocksPerWindow, maxAttentionWindowVec) = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow);
std::tie(blocksPerWindow, maxAttentionWindowVec)
= clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow, failFastOnAttentionWindowTooLarge);
}

kv_cache_manager::TempAttentionWindowInputs tempAttentionWindowInputs;
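The new parameter only changes the recovery path: when the configured attention windows cannot fit even one sequence in the KV cache, the model either clamps them (previous behavior, with a warning) or throws immediately. A simplified, self-contained sketch of that decision, using a hypothetical helper rather than the actual member function, which additionally rebuilds the blocks-per-window map and adjusts the max sequence length:

```cpp
#include <algorithm>
#include <stdexcept>
#include <vector>

// Illustration only: windowSizes holds the configured attention windows; maxFit is the
// largest window that still lets a single sequence fit into the KV cache.
std::vector<int> clampOrFail(std::vector<int> const& windowSizes, int maxFit, bool failFast)
{
    std::vector<int> clamped;
    clamped.reserve(windowSizes.size());
    bool changed = false;
    for (int w : windowSizes)
    {
        clamped.push_back(std::min(w, maxFit));
        changed = changed || (w > maxFit);
    }
    if (!changed)
    {
        return windowSizes; // everything already fits, nothing to adjust
    }
    if (failFast)
    {
        // New behavior: surface the misconfiguration instead of silently adjusting it away.
        throw std::runtime_error("Attention window too large to fit a single sequence in the KV cache");
    }
    // Legacy behavior: warn (omitted here) and continue with the clamped windows.
    return clamped;
}
```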
7 changes: 5 additions & 2 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -280,7 +280,8 @@ class TrtGptModelInflightBatching : public TrtGptModel
void createBuffers(executor::DecodingConfig const& decodingConfig,
std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs);
std::unique_ptr<KVCacheManager> createKvCacheManager(KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType,
uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory);
uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory,
bool const failFastOnAttentionWindowTooLarge = false);
void createRnnStateManager();
void createCustomAllReduceWorkspace();
void createRuntimePerfKnobsTensor(executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);
@@ -378,9 +379,11 @@ class TrtGptModelInflightBatching : public TrtGptModel
/// window.
///
/// @param blocksPerWindow map of window size to number of blocks.
/// @param failFastOnAttentionWindowTooLarge if true, the function will report a runtime error if the attention
/// window is too large to fit even a single sequence in the KV cache.
/// @return pair of new blocks per window and new maxAttentionWindowVec
[[nodiscard]] std::pair<BlocksPerWindow, std::vector<SizeType32>> clampWindowSizesToFitAtLeastOneSequence(
BlocksPerWindow const& blocksPerWindow);
BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge = false);

/// @brief Change the speculative decoding mode.
void changeSpecDecMode(ScheduledRequests const& scheduledRequests);
13 changes: 12 additions & 1 deletion cpp/tensorrt_llm/executor/executorConfig.cpp
@@ -34,7 +34,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs,
std::optional<CacheTransceiverConfig> cacheTransceiverConfig, bool gatherGenerationLogits,
bool promptTableOffloading, bool enableTrtOverlap)
bool promptTableOffloading, bool enableTrtOverlap, bool failFastOnAttentionWindowTooLarge)
: mMaxBeamWidth(maxBeamWidth)
, mSchedulerConfig(std::move(schedulerConfig))
, mKvCacheConfig(std::move(kvCacheConfig))
@@ -63,6 +63,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
, mGatherGenerationLogits(gatherGenerationLogits)
, mPromptTableOffloading(promptTableOffloading)
, mEnableTrtOverlap(enableTrtOverlap)
, mFailFastOnAttentionWindowTooLarge(failFastOnAttentionWindowTooLarge)
{
TLLM_CHECK(iterStatsMaxIterations >= 0);
TLLM_CHECK(requestStatsMaxIterations >= 0);
@@ -222,6 +223,11 @@ bool ExecutorConfig::getEnableTrtOverlap() const
return mEnableTrtOverlap;
}

bool ExecutorConfig::getFailFastOnAttentionWindowTooLarge() const
{
return mFailFastOnAttentionWindowTooLarge;
}

// setters

void ExecutorConfig::setMaxBeamWidth(SizeType32 maxBeamWidth)
@@ -371,4 +377,9 @@ void ExecutorConfig::setEnableTrtOverlap(bool enableTrtOverlap)
mEnableTrtOverlap = enableTrtOverlap;
}

void ExecutorConfig::setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge)
{
mFailFastOnAttentionWindowTooLarge = failFastOnAttentionWindowTooLarge;
}

} // namespace tensorrt_llm::executor
15 changes: 10 additions & 5 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -459,7 +459,7 @@ void initConfigBindings(pybind11::module_& m)
c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
c.getPromptTableOffloading(), c.getEnableTrtOverlap());
c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
return pickle_tuple;
};
@@ -472,7 +472,7 @@ void initConfigBindings(pybind11::module_& m)

// Restore C++ data
auto cpp_states = state[0].cast<py::tuple>();
if (cpp_states.size() != 28)
if (cpp_states.size() != 29)
{
throw std::runtime_error("Invalid cpp_states!");
}
@@ -505,7 +505,8 @@ void initConfigBindings(pybind11::module_& m)
cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(), // CacheTransceiverConfig
cpp_states[25].cast<bool>(), // GatherGenerationLogits
cpp_states[26].cast<bool>(), // PromptTableOffloading
cpp_states[27].cast<bool>() // EnableTrtOverlap
cpp_states[27].cast<bool>(), // EnableTrtOverlap
cpp_states[28].cast<bool>() // FailFastOnAttentionWindowTooLarge
);

auto py_state = state[1].cast<py::dict>();
@@ -542,7 +543,8 @@ void initConfigBindings(pybind11::module_& m)
std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
bool, // GatherGenerationLogits
bool, // PromptTableOffloading
bool // EnableTrtOverlap
bool, // EnableTrtOverlap
bool // FailFastOnAttentionWindowTooLarge
>(),
py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
@@ -563,7 +565,7 @@ void initConfigBindings(pybind11::module_& m)
py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
py::arg("enable_trt_overlap") = false)
py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
.def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
.def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
.def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
@@ -613,6 +615,9 @@ void initConfigBindings(pybind11::module_& m)
&tle::ExecutorConfig::setPromptTableOffloading)
.def_property(
"enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
.def_property("fail_fast_on_attention_window_too_large",
&tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
&tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
.def(py::pickle(executorConfigGetState, executorConfigSetState));
}

15 changes: 14 additions & 1 deletion examples/run.py
@@ -106,6 +106,13 @@ def parse_arguments(args=None):
default=False,
action='store_true',
help="Run several 10 iterations to profile the inference latencies.")
parser.add_argument(
'--fail_fast_on_attention_window_too_large',
action='store_true',
default=False,
help=
'Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.'
)

parser = add_common_args(parser)

@@ -455,6 +462,8 @@ def main(args):
gpu_weights_percent=args.gpu_weights_percent,
max_output_len=args.max_output_len,
enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
fail_fast_on_attention_window_too_large=args.
fail_fast_on_attention_window_too_large,
)
if args.medusa_choices is not None:
args.medusa_choices = ast.literal_eval(args.medusa_choices)
@@ -549,6 +558,8 @@ def main(args):
eagle_choices=args.eagle_choices,
return_all_generated_tokens=args.return_all_generated_tokens,
input_token_extra_ids=input_token_extra_ids,
fail_fast_on_attention_window_too_large=args.
fail_fast_on_attention_window_too_large,
language_adapter_uids=args.language_task_uids)
torch.cuda.synchronize()

@@ -680,7 +691,9 @@ def main(args):
return_dict=True,
return_all_generated_tokens=args.
return_all_generated_tokens,
input_token_extra_ids=input_token_extra_ids)
input_token_extra_ids=input_token_extra_ids,
fail_fast_on_attention_window_too_large=args.
fail_fast_on_attention_window_too_large)
torch.cuda.synchronize()
tensorrt_llm.profiler.stop("tmp")
