diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6f1bde47362..5406e7693d4 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -92,7 +92,16 @@ def Field(default: Any = ..., return PydanticField(default, **kwargs) -class CudaGraphConfig(BaseModel): +class StrictBaseModel(BaseModel): + """ + A base model that forbids arbitrary fields. + """ + + class Config: + extra = "forbid" # globally forbid arbitrary fields + + +class CudaGraphConfig(StrictBaseModel): """ Configuration for CUDA graphs. """ @@ -119,8 +128,40 @@ def validate_cuda_graph_max_batch_size(cls, v): "cuda_graph_config.max_batch_size must be non-negative") return v + @staticmethod + def _generate_cuda_graph_batch_sizes(max_batch_size: int, + enable_padding: bool) -> List[int]: + """Generate a list of batch sizes for CUDA graphs. + + Args: + max_batch_size: Maximum batch size to generate up to + enable_padding: Whether padding is enabled, which affects the batch size distribution + + Returns: + List of batch sizes to create CUDA graphs for + """ + if enable_padding: + batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)] + else: + batch_sizes = list(range(1, 32)) + [32, 64, 128] + + # Add powers of 2 up to max_batch_size + batch_sizes += [ + 2**i for i in range(8, math.floor(math.log(max_batch_size, 2))) + ] + + # Filter and sort batch sizes + batch_sizes = sorted( + [size for size in batch_sizes if size <= max_batch_size]) + + # Add max_batch_size if not already included + if max_batch_size != batch_sizes[-1]: + batch_sizes.append(max_batch_size) + + return batch_sizes + -class MoeConfig(BaseModel): +class MoeConfig(StrictBaseModel): """ Configuration for MoE. """ @@ -225,7 +266,7 @@ def to_mapping(self) -> Mapping: auto_parallel=self.auto_parallel) -class CalibConfig(BaseModel): +class CalibConfig(StrictBaseModel): """ Calibration configuration. """ @@ -277,7 +318,7 @@ class _ModelFormatKind(Enum): TLLM_ENGINE = 2 -class DecodingBaseConfig(BaseModel): +class DecodingBaseConfig(StrictBaseModel): max_draft_len: Optional[int] = None speculative_model_dir: Optional[Union[str, Path]] = None @@ -298,6 +339,7 @@ def from_dict(cls, data: dict): config_class = config_classes.get(decoding_type) if config_class is None: raise ValueError(f"Invalid decoding type: {decoding_type}") + data.pop("decoding_type") return config_class(**data) @@ -496,7 +538,7 @@ def mirror_pybind_fields(pybind_class): """ def decorator(cls): - assert issubclass(cls, BaseModel) + assert issubclass(cls, StrictBaseModel) # Get all non-private fields from the C++ class cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class) python_fields = set(cls.model_fields.keys()) @@ -597,7 +639,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig) -class DynamicBatchConfig(BaseModel, PybindMirror): +class DynamicBatchConfig(StrictBaseModel, PybindMirror): """Dynamic batch configuration. Controls how batch size and token limits are dynamically adjusted at runtime. @@ -623,7 +665,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_SchedulerConfig) -class SchedulerConfig(BaseModel, PybindMirror): +class SchedulerConfig(StrictBaseModel, PybindMirror): capacity_scheduler_policy: CapacitySchedulerPolicy = Field( default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, description="The capacity scheduler policy to use") @@ -645,7 +687,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_PeftCacheConfig) -class PeftCacheConfig(BaseModel, PybindMirror): +class PeftCacheConfig(StrictBaseModel, PybindMirror): """ Configuration for the PEFT cache. """ @@ -773,7 +815,7 @@ def supports_backend(self, backend: str) -> bool: @PybindMirror.mirror_pybind_fields(_KvCacheConfig) -class KvCacheConfig(BaseModel, PybindMirror): +class KvCacheConfig(StrictBaseModel, PybindMirror): """ Configuration for the KV cache. """ @@ -856,7 +898,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) -class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror): +class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror): """ Configuration for extended runtime performance knobs. """ @@ -887,7 +929,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig) -class CacheTransceiverConfig(BaseModel, PybindMirror): +class CacheTransceiverConfig(StrictBaseModel, PybindMirror): """ Configuration for the cache transceiver. """ @@ -947,7 +989,7 @@ def model_name(self) -> Union[str, Path]: return self.model if isinstance(self.model, str) else None -class BaseLlmArgs(BaseModel): +class BaseLlmArgs(StrictBaseModel): """ Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. """ @@ -1354,7 +1396,8 @@ def init_build_config(self): """ Creating a default BuildConfig if none is provided """ - if self.build_config is None: + build_config = getattr(self, "build_config", None) + if build_config is None: kwargs = {} if self.max_batch_size: kwargs["max_batch_size"] = self.max_batch_size @@ -1367,10 +1410,10 @@ def init_build_config(self): if self.max_input_len: kwargs["max_input_len"] = self.max_input_len self.build_config = BuildConfig(**kwargs) - - assert isinstance( - self.build_config, BuildConfig - ), f"build_config is not initialized: {self.build_config}" + else: + assert isinstance( + build_config, + BuildConfig), f"build_config is not initialized: {build_config}" return self @model_validator(mode="after") @@ -1813,7 +1856,7 @@ class LoadFormat(Enum): DUMMY = 1 -class TorchCompileConfig(BaseModel): +class TorchCompileConfig(StrictBaseModel): """ Configuration for torch.compile. """ @@ -2049,38 +2092,6 @@ def validate_checkpoint_format(self): return self - @staticmethod - def _generate_cuda_graph_batch_sizes(max_batch_size: int, - enable_padding: bool) -> List[int]: - """Generate a list of batch sizes for CUDA graphs. - - Args: - max_batch_size: Maximum batch size to generate up to - enable_padding: Whether padding is enabled, which affects the batch size distribution - - Returns: - List of batch sizes to create CUDA graphs for - """ - if enable_padding: - batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)] - else: - batch_sizes = list(range(1, 32)) + [32, 64, 128] - - # Add powers of 2 up to max_batch_size - batch_sizes += [ - 2**i for i in range(8, math.floor(math.log(max_batch_size, 2))) - ] - - # Filter and sort batch sizes - batch_sizes = sorted( - [size for size in batch_sizes if size <= max_batch_size]) - - # Add max_batch_size if not already included - if max_batch_size != batch_sizes[-1]: - batch_sizes.append(max_batch_size) - - return batch_sizes - @model_validator(mode="after") def validate_load_balancer(self) -> 'TorchLlmArgs': from .._torch import MoeLoadBalancerConfig @@ -2117,7 +2128,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs': if config.batch_sizes: config.batch_sizes = sorted(config.batch_sizes) if config.max_batch_size != 0: - if config.batch_sizes != self._generate_cuda_graph_batch_sizes( + if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes( config.max_batch_size, config.enable_padding): raise ValueError( "Please don't set both cuda_graph_config.batch_sizes " @@ -2129,7 +2140,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs': config.max_batch_size = max(config.batch_sizes) else: max_batch_size = config.max_batch_size or 128 - generated_sizes = self._generate_cuda_graph_batch_sizes( + generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes( max_batch_size, config.enable_padding) config.batch_sizes = generated_sizes config.max_batch_size = max_batch_size diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 801a2bf12a9..d6990ac745c 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -61,7 +61,6 @@ def test_update_llm_args_with_extra_dict_with_speculative_config(self): decoding_type: Lookahead max_window_size: 4 max_ngram_size: 3 - verification_set_size: 4 """ dict_content = self._yaml_to_dict(yaml_content) @@ -473,3 +472,229 @@ def test_build_config_from_engine(self): assert args.max_num_tokens == 16 assert args.max_batch_size == 4 + + +class TestStrictBaseModelArbitraryArgs: + """Test that StrictBaseModel prevents arbitrary arguments from being accepted.""" + + def test_cuda_graph_config_arbitrary_args(self): + """Test that CudaGraphConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = CudaGraphConfig(batch_sizes=[1, 2, 4], max_batch_size=8) + assert config.batch_sizes == [1, 2, 4] + assert config.max_batch_size == 8 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + CudaGraphConfig(batch_sizes=[1, 2, 4], invalid_arg="should_fail") + assert "invalid_arg" in str(exc_info.value) + + def test_moe_config_arbitrary_args(self): + """Test that MoeConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = MoeConfig(backend="CUTLASS", max_num_tokens=1024) + assert config.backend == "CUTLASS" + assert config.max_num_tokens == 1024 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + MoeConfig(backend="CUTLASS", unknown_field="should_fail") + assert "unknown_field" in str(exc_info.value) + + def test_calib_config_arbitrary_args(self): + """Test that CalibConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = CalibConfig(device="cuda", calib_batches=512) + assert config.device == "cuda" + assert config.calib_batches == 512 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + CalibConfig(device="cuda", extra_field="should_fail") + assert "extra_field" in str(exc_info.value) + + def test_decoding_base_config_arbitrary_args(self): + """Test that DecodingBaseConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = DecodingBaseConfig(max_draft_len=10) + assert config.max_draft_len == 10 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + DecodingBaseConfig(max_draft_len=10, random_field="should_fail") + assert "random_field" in str(exc_info.value) + + def test_dynamic_batch_config_arbitrary_args(self): + """Test that DynamicBatchConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = DynamicBatchConfig(enable_batch_size_tuning=True, + enable_max_num_tokens_tuning=True, + dynamic_batch_moving_average_window=8) + assert config.enable_batch_size_tuning == True + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + DynamicBatchConfig(enable_batch_size_tuning=True, + enable_max_num_tokens_tuning=True, + dynamic_batch_moving_average_window=8, + fake_param="should_fail") + assert "fake_param" in str(exc_info.value) + + def test_scheduler_config_arbitrary_args(self): + """Test that SchedulerConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION) + assert config.capacity_scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy. + MAX_UTILIZATION, + invalid_option="should_fail") + assert "invalid_option" in str(exc_info.value) + + def test_peft_cache_config_arbitrary_args(self): + """Test that PeftCacheConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = PeftCacheConfig(num_host_module_layer=1, + num_device_module_layer=1) + assert config.num_host_module_layer == 1 + assert config.num_device_module_layer == 1 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + PeftCacheConfig(num_host_module_layer=1, + unexpected_field="should_fail") + assert "unexpected_field" in str(exc_info.value) + + def test_kv_cache_config_arbitrary_args(self): + """Test that KvCacheConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = KvCacheConfig(enable_block_reuse=True, max_tokens=1024) + assert config.enable_block_reuse == True + assert config.max_tokens == 1024 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + KvCacheConfig(enable_block_reuse=True, + non_existent_field="should_fail") + assert "non_existent_field" in str(exc_info.value) + + def test_extended_runtime_perf_knob_config_arbitrary_args(self): + """Test that ExtendedRuntimePerfKnobConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = ExtendedRuntimePerfKnobConfig(multi_block_mode=True, + cuda_graph_mode=False) + assert config.multi_block_mode == True + assert config.cuda_graph_mode == False + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + ExtendedRuntimePerfKnobConfig(multi_block_mode=True, + bogus_setting="should_fail") + assert "bogus_setting" in str(exc_info.value) + + def test_cache_transceiver_config_arbitrary_args(self): + """Test that CacheTransceiverConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = CacheTransceiverConfig(backend="ucx", + max_tokens_in_buffer=1024) + assert config.backend == "ucx" + assert config.max_tokens_in_buffer == 1024 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + CacheTransceiverConfig(backend="ucx", invalid_config="should_fail") + assert "invalid_config" in str(exc_info.value) + + def test_torch_compile_config_arbitrary_args(self): + """Test that TorchCompileConfig rejects arbitrary arguments.""" + # Valid arguments should work + config = TorchCompileConfig(enable_fullgraph=True, + enable_inductor=False) + assert config.enable_fullgraph == True + assert config.enable_inductor == False + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + TorchCompileConfig(enable_fullgraph=True, + invalid_flag="should_fail") + assert "invalid_flag" in str(exc_info.value) + + def test_trt_llm_args_arbitrary_args(self): + """Test that TrtLlmArgs rejects arbitrary arguments.""" + # Valid arguments should work + args = TrtLlmArgs(model=llama_model_path, max_batch_size=8) + assert args.model == llama_model_path + assert args.max_batch_size == 8 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + TrtLlmArgs(model=llama_model_path, invalid_setting="should_fail") + assert "invalid_setting" in str(exc_info.value) + + def test_torch_llm_args_arbitrary_args(self): + """Test that TorchLlmArgs rejects arbitrary arguments.""" + # Valid arguments should work + args = TorchLlmArgs(model=llama_model_path, max_batch_size=8) + assert args.model == llama_model_path + assert args.max_batch_size == 8 + + # Arbitrary arguments should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + TorchLlmArgs(model=llama_model_path, + unsupported_option="should_fail") + assert "unsupported_option" in str(exc_info.value) + + def test_nested_config_arbitrary_args(self): + """Test that nested configurations also reject arbitrary arguments.""" + # Test with nested KvCacheConfig + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + KvCacheConfig(enable_block_reuse=True, + max_tokens=1024, + invalid_nested_field="should_fail") + assert "invalid_nested_field" in str(exc_info.value) + + # Test with nested SchedulerConfig + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy. + MAX_UTILIZATION, + nested_invalid_field="should_fail") + assert "nested_invalid_field" in str(exc_info.value) + + def test_strict_base_model_inheritance(self): + """Test that StrictBaseModel properly forbids extra fields.""" + # Verify that StrictBaseModel is properly configured + assert StrictBaseModel.model_config.get("extra") == "forbid" + + # Test that a simple StrictBaseModel instance rejects arbitrary fields + class TestConfig(StrictBaseModel): + field1: str = "default" + field2: int = 42 + + # Valid configuration should work + config = TestConfig(field1="test", field2=100) + assert config.field1 == "test" + assert config.field2 == 100 + + # Arbitrary field should be rejected + with pytest.raises( + pydantic_core._pydantic_core.ValidationError) as exc_info: + TestConfig(field1="test", field2=100, extra_field="should_fail") + assert "extra_field" in str(exc_info.value)