Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,15 +1937,9 @@ def update_llm_args_with_extra_dict(
"quant_config": QuantConfig,
"calib_config": CalibConfig,
"build_config": BuildConfig,
"kv_cache_config": KvCacheConfig,
"decoding_config": DecodingConfig,
"enable_build_cache": BuildCacheConfig,
"peft_cache_config": PeftCacheConfig,
"scheduler_config": SchedulerConfig,
"speculative_config": DecodingBaseConfig,
"batching_type": BatchingType,
"extended_runtime_perf_knob_config": ExtendedRuntimePerfKnobConfig,
"cache_transceiver_config": CacheTransceiverConfig,
"lora_config": LoraConfig,
}
for field_name, field_type in field_mapping.items():
Expand Down
111 changes: 92 additions & 19 deletions tests/unittest/llmapi/test_llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
assert pybind_config.max_verification_set_size == 4


def test_update_llm_args_with_extra_dict_with_speculative_config():
yaml_content = """
class TestYaml:
    """Tests that extra YAML-provided dicts are merged correctly into TrtLlmArgs.

    Each test round-trips a YAML snippet through ``_yaml_to_dict`` and feeds it
    to ``update_llm_args_with_extra_dict`` before re-validating ``TrtLlmArgs``.
    """

    def _yaml_to_dict(self, yaml_content: str) -> dict:
        """Write *yaml_content* to a temp file and parse it back into a dict.

        Uses the default ``delete=True`` so the temp file is removed when the
        context exits (``delete=False`` would leak one file per call).
        """
        with tempfile.NamedTemporaryFile() as f:
            f.write(yaml_content.encode('utf-8'))
            f.flush()
            f.seek(0)
            dict_content = yaml.safe_load(f)
        return dict_content

    def test_update_llm_args_with_extra_dict_with_speculative_config(self):
        # speculative_config is polymorphic: decoding_type selects the subclass.
        yaml_content = """
speculative_config:
    decoding_type: Lookahead
    max_window_size: 4
    max_ngram_size: 3
    verification_set_size: 4
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.speculative_config.max_window_size == 4
        assert llm_args.speculative_config.max_ngram_size == 3
        assert llm_args.speculative_config.max_verification_set_size == 4

    def test_llm_args_with_invalid_yaml(self):
        # A deprecated/unknown top-level key must be rejected at validation.
        yaml_content = """
pytorch_backend_config: # this is deprecated
    max_num_tokens: 1
    max_seq_len: 1
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        with pytest.raises(ValueError):
            llm_args = TrtLlmArgs(**llm_args_dict)

    def test_llm_args_with_build_config(self):
        # build_config isn't a Pydantic model, so it takes a dedicated
        # conversion path in update_llm_args_with_extra_dict.
        yaml_content = """
build_config:
    max_beam_width: 4
    max_batch_size: 8
    max_num_tokens: 256
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.build_config.max_beam_width == 4
        assert llm_args.build_config.max_batch_size == 8
        assert llm_args.build_config.max_num_tokens == 256

    def test_llm_args_with_kvcache_config(self):
        # kv_cache_config is a Pydantic sub-config; nested fields (including a
        # list value) should survive the dict merge.
        yaml_content = """
kv_cache_config:
    enable_block_reuse: True
    max_tokens: 1024
    max_attention_window: [1024, 1024, 1024]
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.kv_cache_config.enable_block_reuse == True
        assert llm_args.kv_cache_config.max_tokens == 1024
        assert llm_args.kv_cache_config.max_attention_window == [
            1024, 1024, 1024
        ]

    def test_llm_args_with_pydantic_options(self):
        # Plain scalar top-level options merge directly onto TrtLlmArgs fields.
        yaml_content = """
max_batch_size: 16
max_num_tokens: 256
max_seq_len: 128
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.max_batch_size == 16
        assert llm_args.max_num_tokens == 256
        assert llm_args.max_seq_len == 128


def check_defaults(py_config_cls, pybind_config_cls):
Expand Down