diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index bba3c31a014..6d592654ffd 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1484,7 +1484,8 @@ class ExecutorConfig
         std::optional<GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
         std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
         std::optional<CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt,
-        bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false);
+        bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false,
+        bool failFastOnAttentionWindowTooLarge = false);

     [[nodiscard]] SizeType32 getMaxBeamWidth() const;
     [[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -1519,6 +1520,7 @@ class ExecutorConfig
     [[nodiscard]] bool getPromptTableOffloading() const;
     [[nodiscard]] std::optional<CacheTransceiverConfig> getCacheTransceiverConfig() const;
     [[nodiscard]] bool getEnableTrtOverlap() const;
+    [[nodiscard]] bool getFailFastOnAttentionWindowTooLarge() const;

     void setMaxBeamWidth(SizeType32 maxBeamWidth);
     void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -1548,6 +1550,7 @@ class ExecutorConfig
     void setPromptTableOffloading(bool promptTableOffloading);
     void setCacheTransceiverConfig(CacheTransceiverConfig const& cacheTransceiverConfig);
     void setEnableTrtOverlap(bool enableTrtOverlap);
+    void setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge);

 private:
     friend class Serialization;
@@ -1634,6 +1637,10 @@ class ExecutorConfig

     /// @brief Controls whether preparation and TRT engine execution should be overlapped.
     bool mEnableTrtOverlap{false};
+
+    /// @brief Controls whether to fail fast when attention window is too large to fit even a single sequence in the KV
+    /// cache.
+    bool mFailFastOnAttentionWindowTooLarge{false};
 };

 struct KVCacheCreatedData
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index 80418b2bc73..4a5ddb89286 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -296,7 +296,6 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr
             mRuntime->getBufferManager(), kvCacheConfig);
-
     if (mModelConfig.useCrossAttention())
     {
         TLLM_CHECK_WITH_INFO(kvCacheConfig.getCrossKvCacheFraction().has_value(),
@@ -304,10 +303,11 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr
 std::pair<BlocksPerWindow, std::vector<SizeType32>>
-TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWindow const& blocksPerWindow)
+TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(
+    BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge)
 {
     // At this point, we can only validate that the cheapest sequence in terms of kv-cache resources still fits. More
     // validation is needed on a per-request basis, once the prompt / output lengths and the actual beam width are
@@ -591,6 +592,16 @@
     }
     TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
         common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());
+
+    if (failFastOnAttentionWindowTooLarge)
+    {
+        throw std::runtime_error(
+            "Attention window too large to fit even a single sequence in the KV cache. Failing fast rather than "
+            "attempting an adjustment of the window sizes. "
+            "Old: "
+            + common::vec2str(getMaxAttentionWindowVec()) + ", New: " + common::vec2str(newMaxAttentionWindowVec));
+    }
+
     setMaxAttentionWindowVec(newMaxAttentionWindowVec);
     if (getMaxSequenceLen() > getMaxAttentionWindow())
     {
@@ -613,7 +624,7 @@
 std::unique_ptr TrtGptModelInflightBatching::createKvCacheManager(
     KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType, uint64_t freePrimaryMemBytes,
-    uint64_t freeSecondaryMemBytes, size_t extraCostMemory)
+    uint64_t freeSecondaryMemBytes, size_t extraCostMemory, bool const failFastOnAttentionWindowTooLarge)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     bool isCrossAttention = kvCacheType == KvCacheType::kCROSS;
@@ -657,7 +668,8 @@ std::unique_ptr TrtGptModelInflightBatching::c
     // and user also didn't provide maxAttentionWindow, which leads it to be equal to maxSeqLen
     if (kvCacheType == KvCacheType::kSELF)
     {
-        std::tie(blocksPerWindow, maxAttentionWindowVec) = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow);
+        std::tie(blocksPerWindow, maxAttentionWindowVec)
+            = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow, failFastOnAttentionWindowTooLarge);
     }

     kv_cache_manager::TempAttentionWindowInputs tempAttentionWindowInputs;
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
index 6e9f1c8ce0f..28d1767525c 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -280,7 +280,8 @@ class TrtGptModelInflightBatching : public TrtGptModel
     void createBuffers(executor::DecodingConfig const& decodingConfig,
         std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs);
     std::unique_ptr createKvCacheManager(KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType,
-        uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory);
+        uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory,
+        bool const failFastOnAttentionWindowTooLarge = false);
     void createRnnStateManager();
     void createCustomAllReduceWorkspace();
     void createRuntimePerfKnobsTensor(executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);
@@ -378,9 +379,11 @@ class TrtGptModelInflightBatching : public TrtGptModel
     /// window.
     ///
     /// @param blocksPerWindow map of window size to number of blocks.
+    /// @param failFastOnAttentionWindowTooLarge if true, the function will report a runtime error if the attention
+    /// window is too large to fit even a single sequence in the KV cache.
     /// @return pair of new blocks per window and new maxAttentionWindowVec
     [[nodiscard]] std::pair<BlocksPerWindow, std::vector<SizeType32>> clampWindowSizesToFitAtLeastOneSequence(
-        BlocksPerWindow const& blocksPerWindow);
+        BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge = false);

     /// @brief Change the speculative decoding mode.
     void changeSpecDecMode(ScheduledRequests const& scheduledRequests);
diff --git a/cpp/tensorrt_llm/executor/executorConfig.cpp b/cpp/tensorrt_llm/executor/executorConfig.cpp
index 275d3605e70..2dff78280f5 100644
--- a/cpp/tensorrt_llm/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/executor/executorConfig.cpp
@@ -34,7 +34,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,
     std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs,
     std::optional<CacheTransceiverConfig> cacheTransceiverConfig, bool gatherGenerationLogits,
-    bool promptTableOffloading, bool enableTrtOverlap)
+    bool promptTableOffloading, bool enableTrtOverlap, bool failFastOnAttentionWindowTooLarge)
     : mMaxBeamWidth(maxBeamWidth)
     , mSchedulerConfig(std::move(schedulerConfig))
     , mKvCacheConfig(std::move(kvCacheConfig))
@@ -63,6 +63,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     , mGatherGenerationLogits(gatherGenerationLogits)
     , mPromptTableOffloading(promptTableOffloading)
     , mEnableTrtOverlap(enableTrtOverlap)
+    , mFailFastOnAttentionWindowTooLarge(failFastOnAttentionWindowTooLarge)
 {
     TLLM_CHECK(iterStatsMaxIterations >= 0);
     TLLM_CHECK(requestStatsMaxIterations >= 0);
@@ -222,6 +223,11 @@ bool ExecutorConfig::getEnableTrtOverlap() const
     return mEnableTrtOverlap;
 }

+bool ExecutorConfig::getFailFastOnAttentionWindowTooLarge() const
+{
+    return mFailFastOnAttentionWindowTooLarge;
+}
+
 // setters

 void ExecutorConfig::setMaxBeamWidth(SizeType32 maxBeamWidth)
@@ -371,4 +377,9 @@ void ExecutorConfig::setEnableTrtOverlap(bool enableTrtOverlap)
     mEnableTrtOverlap = enableTrtOverlap;
 }

+void ExecutorConfig::setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge)
+{
+    mFailFastOnAttentionWindowTooLarge = failFastOnAttentionWindowTooLarge;
+}
+
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
index 87f32635866..ccbb21aab21 100644
--- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -459,7 +459,7 @@ void initConfigBindings(pybind11::module_& m)
             c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
             c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
             c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
-            c.getPromptTableOffloading(), c.getEnableTrtOverlap());
+            c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
         auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
         return pickle_tuple;
     };
@@ -472,7 +472,7 @@ void initConfigBindings(pybind11::module_& m)
         // Restore C++ data
         auto cpp_states = state[0].cast<py::tuple>();
-        if (cpp_states.size() != 28)
+        if (cpp_states.size() != 29)
         {
             throw std::runtime_error("Invalid cpp_states!");
         }

@@ -505,7 +505,8 @@ void initConfigBindings(pybind11::module_& m)
             cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(), // CacheTransceiverConfig
             cpp_states[25].cast<bool>(),                                       // GatherGenerationLogits
             cpp_states[26].cast<bool>(),                                       // PromptTableOffloading
-            cpp_states[27].cast<bool>()                                        // EnableTrtOverlap
+            cpp_states[27].cast<bool>(),                                       // EnableTrtOverlap
+            cpp_states[28].cast<bool>()                                        // FailFastOnAttentionWindowTooLarge
         );

         auto py_state = state[1].cast<py::dict>();
@@ -542,7 +543,8 @@ void initConfigBindings(pybind11::module_& m)
                 std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
                 bool,                                       // GatherGenerationLogits
                 bool,                                       // PromptTableOffloading
-                bool                                        // EnableTrtOverlap
+                bool,                                       // EnableTrtOverlap
+                bool                                        // FailFastOnAttentionWindowTooLarge
                 >(),
             py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
             py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
@@ -563,7 +565,7 @@ void initConfigBindings(pybind11::module_& m)
             py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
             py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
             py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
-            py::arg("enable_trt_overlap") = false)
+            py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
         .def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
         .def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
         .def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
@@ -613,6 +615,9 @@ void initConfigBindings(pybind11::module_& m)
             &tle::ExecutorConfig::setPromptTableOffloading)
         .def_property(
             "enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
+        .def_property("fail_fast_on_attention_window_too_large",
+            &tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
+            &tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
         .def(py::pickle(executorConfigGetState, executorConfigSetState));
 }
diff --git a/examples/run.py b/examples/run.py
index 3e46e9d9f6c..0f19b56d768 100755
--- a/examples/run.py
+++ b/examples/run.py
@@ -106,6 +106,13 @@ def parse_arguments(args=None):
         default=False,
         action='store_true',
         help="Run several 10 iterations to profile the inference latencies.")
+    parser.add_argument(
+        '--fail_fast_on_attention_window_too_large',
+        action='store_true',
+        default=False,
+        help=
+        'Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.'
+    )

     parser = add_common_args(parser)

@@ -455,6 +462,8 @@ def main(args):
         gpu_weights_percent=args.gpu_weights_percent,
         max_output_len=args.max_output_len,
         enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
+        fail_fast_on_attention_window_too_large=args.
+        fail_fast_on_attention_window_too_large,
     )
     if args.medusa_choices is not None:
         args.medusa_choices = ast.literal_eval(args.medusa_choices)
@@ -549,6 +558,8 @@ def main(args):
             eagle_choices=args.eagle_choices,
             return_all_generated_tokens=args.return_all_generated_tokens,
             input_token_extra_ids=input_token_extra_ids,
+            fail_fast_on_attention_window_too_large=args.
+            fail_fast_on_attention_window_too_large,
             language_adapter_uids=args.language_task_uids)
         torch.cuda.synchronize()

@@ -680,7 +691,9 @@ def main(args):
                 return_dict=True,
                 return_all_generated_tokens=args.
                 return_all_generated_tokens,
-                input_token_extra_ids=input_token_extra_ids)
+                input_token_extra_ids=input_token_extra_ids,
+                fail_fast_on_attention_window_too_large=args.
+                fail_fast_on_attention_window_too_large)
             torch.cuda.synchronize()
             tensorrt_llm.profiler.stop("tmp")
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 7de263ea89f..4f26be6579b 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -84,6 +84,7 @@ def get_llm_args(model: str,
                  num_postprocess_workers: int = 0,
                  trust_remote_code: bool = False,
                  reasoning_parser: Optional[str] = None,
+                 fail_fast_on_attention_window_too_large: bool = False,
                  **llm_args_extra_dict: Any):

     if gpus_per_node is None:
@@ -107,24 +108,44 @@ def get_llm_args(model: str,
     )

     llm_args = {
-        "model": model,
-        "scheduler_config": scheduler_config,
-        "tokenizer": tokenizer,
-        "tensor_parallel_size": tensor_parallel_size,
-        "pipeline_parallel_size": pipeline_parallel_size,
-        "moe_expert_parallel_size": moe_expert_parallel_size,
-        "gpus_per_node": gpus_per_node,
-        "trust_remote_code": trust_remote_code,
-        "build_config": build_config,
-        "max_batch_size": max_batch_size,
-        "max_num_tokens": max_num_tokens,
-        "max_beam_width": max_beam_width,
-        "max_seq_len": max_seq_len,
-        "kv_cache_config": kv_cache_config,
-        "backend": backend if backend == "pytorch" else None,
-        "num_postprocess_workers": num_postprocess_workers,
-        "postprocess_tokenizer_dir": tokenizer or model,
-        "reasoning_parser": reasoning_parser,
+        "model":
+        model,
+        "scheduler_config":
+        scheduler_config,
+        "tokenizer":
+        tokenizer,
+        "tensor_parallel_size":
+        tensor_parallel_size,
+        "pipeline_parallel_size":
+        pipeline_parallel_size,
+        "moe_expert_parallel_size":
+        moe_expert_parallel_size,
+        "gpus_per_node":
+        gpus_per_node,
+        "trust_remote_code":
+        trust_remote_code,
+        "build_config":
+        build_config,
+        "max_batch_size":
+        max_batch_size,
+        "max_num_tokens":
+        max_num_tokens,
+        "max_beam_width":
+        max_beam_width,
+        "max_seq_len":
+        max_seq_len,
+        "kv_cache_config":
+        kv_cache_config,
+        "backend":
+        backend if backend == "pytorch" else None,
+        "num_postprocess_workers":
+        num_postprocess_workers,
+        "postprocess_tokenizer_dir":
+        tokenizer or model,
+        "reasoning_parser":
+        reasoning_parser,
+        "fail_fast_on_attention_window_too_large":
+        fail_fast_on_attention_window_too_large,
     }

     return llm_args, llm_args_extra_dict
@@ -249,16 +270,23 @@ def launch_server(host: str,
     default=None,
     help="Server role. Specify this value only if running in disaggregated mode."
 )
-def serve(model: str, tokenizer: Optional[str], host: str, port: int,
-          log_level: str, backend: str, max_beam_width: int,
-          max_batch_size: int, max_num_tokens: int, max_seq_len: int,
-          tp_size: int, pp_size: int, ep_size: Optional[int],
-          cluster_size: Optional[int], gpus_per_node: Optional[int],
-          kv_cache_free_gpu_memory_fraction: float,
-          num_postprocess_workers: int, trust_remote_code: bool,
-          extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
-          metadata_server_config_file: Optional[str],
-          server_role: Optional[str]):
+@click.option(
+    "--fail_fast_on_attention_window_too_large",
+    is_flag=True,
+    default=False,
+    help=
+    "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
+)
+def serve(
+        model: str, tokenizer: Optional[str], host: str, port: int,
+        log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
+        max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
+        ep_size: Optional[int], cluster_size: Optional[int],
+        gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
+        num_postprocess_workers: int, trust_remote_code: bool,
+        extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
+        metadata_server_config_file: Optional[str], server_role: Optional[str],
+        fail_fast_on_attention_window_too_large: bool):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -281,7 +309,9 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
         free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
         num_postprocess_workers=num_postprocess_workers,
         trust_remote_code=trust_remote_code,
-        reasoning_parser=reasoning_parser)
+        reasoning_parser=reasoning_parser,
+        fail_fast_on_attention_window_too_large=
+        fail_fast_on_attention_window_too_large)

     llm_args_extra_dict = {}
     if extra_llm_api_options is not None:
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 934813aa4c4..dcf3ca92902 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -779,7 +779,9 @@ def _build_model(self):
                 or tllm.BatchingType.INFLIGHT,
                 max_batch_size=max_batch_size,
                 max_num_tokens=max_num_tokens,
-                gather_generation_logits=self.args.gather_generation_logits)
+                gather_generation_logits=self.args.gather_generation_logits,
+                fail_fast_on_attention_window_too_large=getattr(
+                    self.args, 'fail_fast_on_attention_window_too_large', False))

             # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
             if max_seq_len is not None:
@@ -920,7 +922,9 @@ def _build_model(self):
                 or tllm.BatchingType.INFLIGHT,
                 max_batch_size=max_batch_size,
                 max_num_tokens=max_num_tokens,
-                gather_generation_logits=self.args.gather_generation_logits)
+                gather_generation_logits=self.args.gather_generation_logits,
+                fail_fast_on_attention_window_too_large=getattr(
+                    self.args, 'fail_fast_on_attention_window_too_large', False))

         if self.args.kv_cache_config is not None:
             self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 6614391b452..a563bc98f28 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -998,6 +998,12 @@ class BaseLlmArgs(BaseModel):
         description="The format to load the model.",
         json_schema_extra={"type": "Literal['auto', 'dummy']"})

+    fail_fast_on_attention_window_too_large: bool = Field(
+        default=False,
+        description=
+        "Fail fast when attention window is too large to fit even a single sequence in the KV cache."
+    )
+
     # LoRA arguments
     enable_lora: bool = Field(default=False, description="Enable LoRA.")

diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py
index a9f0fe8de40..ee35da3ef0e 100644
--- a/tensorrt_llm/runtime/model_runner.py
+++ b/tensorrt_llm/runtime/model_runner.py
@@ -646,6 +646,7 @@ def from_dir(
         gpu_weights_percent: float = 1,
         enable_context_fmha_fp32_acc: Optional[bool] = None,
         multi_block_mode: Optional[bool] = None,
+        fail_fast_on_attention_window_too_large: bool = False,
     ) -> 'ModelRunner':
         """
         Create a ModelRunner instance from an engine directory.
@@ -667,6 +668,9 @@ def from_dir(
                 Stream to use.
             multi_block_mode (bool):
                 Whether to distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel.
+            fail_fast_on_attention_window_too_large (bool):
+                Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.
+                Note: This parameter is only applicable to C++ runtime (ModelRunnerCpp).
         Returns:
             ModelRunner: An instance of ModelRunner.
         """
diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py
index 239c88d060f..b701f245f6f 100644
--- a/tensorrt_llm/runtime/model_runner_cpp.py
+++ b/tensorrt_llm/runtime/model_runner_cpp.py
@@ -124,6 +124,7 @@ def from_dir(
         gather_generation_logits: bool = False,
         use_variable_beam_width_search: bool = False,
         mm_embedding_offloading: bool = False,
+        fail_fast_on_attention_window_too_large: bool = False,
     ) -> 'ModelRunnerCpp':
         """
         Create a ModelRunnerCpp instance from an engine directory.
@@ -197,6 +198,8 @@ def from_dir(
                 The mode to run the model-runner, Leader mode by default.
             gather_generation_logits (bool):
                 Enable gathering generation logits.
+            fail_fast_on_attention_window_too_large (bool):
+                Whether to fail fast if the attention window(s) are too large to fit even a single sequence in the KVCache.
         Returns:
             ModelRunnerCpp: An instance of ModelRunnerCpp.
         """
@@ -398,6 +401,7 @@ def from_dir(
         trtllm_config.enable_chunked_context = enable_chunked_context
         trtllm_config.extended_runtime_perf_knob_config = extended_runtime_perf_knob_config
         trtllm_config.mm_embedding_offloading = mm_embedding_offloading
+        trtllm_config.fail_fast_on_attention_window_too_large = fail_fast_on_attention_window_too_large
         if is_orchestrator_mode:
             communication_mode = trtllm.CommunicationMode.ORCHESTRATOR
             path = str(Path(__file__).parent.parent / 'bin' / 'executorWorker')
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 7e4867df50f..8284cc0a0db 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -53,6 +53,10 @@ methods:
       reasoning_parser:
         annotation: Optional[str]
         default: null
+      # Runtime behavior
+      fail_fast_on_attention_window_too_large:
+        annotation: bool
+        default: false
       garbage_collection_gen0_threshold:
         annotation: int
         default: 20000
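
Taken together, the diff threads one new boolean from the C++ ExecutorConfig through the pybind bindings, BaseLlmArgs, trtllm-serve, examples/run.py, and ModelRunnerCpp.from_dir: when the configured attention window cannot fit even a single sequence in the KV cache, the runtime raises instead of clamping the window and only logging a warning. Below is a minimal sketch of how the flag might be exercised from the Python LLM API; it assumes the keyword added to BaseLlmArgs above is accepted by the LLM constructor and that the std::runtime_error thrown in clampWindowSizesToFitAtLeastOneSequence() surfaces as a Python RuntimeError. The model path and prompt are placeholders.

    from tensorrt_llm import LLM

    try:
        llm = LLM(
            model="/path/to/model",  # placeholder: HF checkpoint or TRT-LLM engine dir
            # New flag from this change: raise instead of silently shrinking the
            # attention window when not even one sequence fits in the KV cache.
            fail_fast_on_attention_window_too_large=True,
        )
        outputs = llm.generate(["Hello, my name is"])  # placeholder prompt
        print(outputs[0].outputs[0].text)
    except RuntimeError as err:
        # Assumption: the failure propagates here as a RuntimeError; without the
        # flag, the window would instead be clamped and a warning logged.
        print(f"Attention window too large for the KV cache: {err}")

The same behavior is reachable from the command line through the flags added above, namely --fail_fast_on_attention_window_too_large on trtllm-serve and on examples/run.py.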