diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index c5ba1f18c1e..0825e717cd7 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1937,15 +1937,9 @@ def update_llm_args_with_extra_dict(
         "quant_config": QuantConfig,
         "calib_config": CalibConfig,
         "build_config": BuildConfig,
-        "kv_cache_config": KvCacheConfig,
         "decoding_config": DecodingConfig,
         "enable_build_cache": BuildCacheConfig,
-        "peft_cache_config": PeftCacheConfig,
-        "scheduler_config": SchedulerConfig,
         "speculative_config": DecodingBaseConfig,
-        "batching_type": BatchingType,
-        "extended_runtime_perf_knob_config": ExtendedRuntimePerfKnobConfig,
-        "cache_transceiver_config": CacheTransceiverConfig,
         "lora_config": LoraConfig,
     }
     for field_name, field_type in field_mapping.items():
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 8424f9a4e77..b2eb9e8d8cd 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
     assert pybind_config.max_verification_set_size == 4
 
 
-def test_update_llm_args_with_extra_dict_with_speculative_config():
-    yaml_content = """
+class TestYaml:
+
+    def _yaml_to_dict(self, yaml_content: str) -> dict:
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            f.write(yaml_content.encode('utf-8'))
+            f.flush()
+            f.seek(0)
+            dict_content = yaml.safe_load(f)
+        return dict_content
+
+    def test_update_llm_args_with_extra_dict_with_speculative_config(self):
+        yaml_content = """
 speculative_config:
-  decoding_type: Lookahead
-  max_window_size: 4
-  max_ngram_size: 3
-  verification_set_size: 4
+    decoding_type: Lookahead
+    max_window_size: 4
+    max_ngram_size: 3
+    verification_set_size: 4
+    """
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.speculative_config.max_window_size == 4
+        assert llm_args.speculative_config.max_ngram_size == 3
+        assert llm_args.speculative_config.max_verification_set_size == 4
+
+    def test_llm_args_with_invalid_yaml(self):
+        yaml_content = """
+pytorch_backend_config: # this is deprecated
+    max_num_tokens: 1
+    max_seq_len: 1
 """
-    with tempfile.NamedTemporaryFile(delete=False) as f:
-        f.write(yaml_content.encode('utf-8'))
-        f.flush()
-        f.seek(0)
-        dict_content = yaml.safe_load(f)
-
-    llm_args = TrtLlmArgs(model=llama_model_path)
-    llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
-                                                    dict_content)
-    llm_args = TrtLlmArgs(**llm_args_dict)
-    assert llm_args.speculative_config.max_window_size == 4
-    assert llm_args.speculative_config.max_ngram_size == 3
-    assert llm_args.speculative_config.max_verification_set_size == 4
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        with pytest.raises(ValueError):
+            llm_args = TrtLlmArgs(**llm_args_dict)
+
+    def test_llm_args_with_build_config(self):
+        # build_config isn't a Pydantic
+        yaml_content = """
+build_config:
+    max_beam_width: 4
+    max_batch_size: 8
+    max_num_tokens: 256
+    """
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.build_config.max_beam_width == 4
+        assert llm_args.build_config.max_batch_size == 8
+        assert llm_args.build_config.max_num_tokens == 256
+
+    def test_llm_args_with_kvcache_config(self):
+        yaml_content = """
+kv_cache_config:
+    enable_block_reuse: True
+    max_tokens: 1024
+    max_attention_window: [1024, 1024, 1024]
+    """
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.kv_cache_config.enable_block_reuse == True
+        assert llm_args.kv_cache_config.max_tokens == 1024
+        assert llm_args.kv_cache_config.max_attention_window == [
+            1024, 1024, 1024
+        ]
+
+    def test_llm_args_with_pydantic_options(self):
+        yaml_content = """
+max_batch_size: 16
+max_num_tokens: 256
+max_seq_len: 128
+    """
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.max_batch_size == 16
+        assert llm_args.max_num_tokens == 256
+        assert llm_args.max_seq_len == 128
 
 
 def check_defaults(py_config_cls, pybind_config_cls):
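
Note (not part of the patch): the entries dropped from field_mapping (kv_cache_config, peft_cache_config, scheduler_config, batching_type, extended_runtime_perf_knob_config, cache_transceiver_config) are presumably the sections that TrtLlmArgs already validates as Pydantic fields, which is the path the new TestYaml cases exercise. A minimal usage sketch of that path, assuming TrtLlmArgs and update_llm_args_with_extra_dict are importable from tensorrt_llm.llmapi.llm_args as in the test file; the model path and "extra_llm_args.yaml" file name are illustrative.

    import yaml

    from tensorrt_llm.llmapi.llm_args import (TrtLlmArgs,
                                              update_llm_args_with_extra_dict)

    # Extra options supplied as a YAML file, e.g.
    # kv_cache_config:
    #     max_tokens: 1024
    # max_batch_size: 16
    with open("extra_llm_args.yaml") as f:
        extra_dict = yaml.safe_load(f)

    llm_args = TrtLlmArgs(model="/path/to/llama")
    llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                    extra_dict)
    # Nested sections such as kv_cache_config remain plain dicts here;
    # re-validating through TrtLlmArgs turns them into typed config objects
    # (or raises ValueError for deprecated/unknown keys, as the
    # test_llm_args_with_invalid_yaml case checks).
    llm_args = TrtLlmArgs(**llm_args_dict)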