115 changes: 63 additions & 52 deletions tensorrt_llm/llmapi/llm_args.py
@@ -92,7 +92,16 @@ def Field(default: Any = ...,
return PydanticField(default, **kwargs)


class CudaGraphConfig(BaseModel):
class StrictBaseModel(BaseModel):
"""
A base model that forbids arbitrary fields.
"""

class Config:
extra = "forbid" # globally forbid arbitrary fields


class CudaGraphConfig(StrictBaseModel):
"""
Configuration for CUDA graphs.
"""
@@ -119,8 +128,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
"cuda_graph_config.max_batch_size must be non-negative")
return v

@staticmethod
def _generate_cuda_graph_batch_sizes(max_batch_size: int,
enable_padding: bool) -> List[int]:
"""Generate a list of batch sizes for CUDA graphs.

Args:
max_batch_size: Maximum batch size to generate up to
enable_padding: Whether padding is enabled, which affects the batch size distribution

Returns:
List of batch sizes to create CUDA graphs for
"""
if enable_padding:
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
else:
batch_sizes = list(range(1, 32)) + [32, 64, 128]

# Add powers of 2 up to max_batch_size
batch_sizes += [
2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
]

# Filter and sort batch sizes
batch_sizes = sorted(
[size for size in batch_sizes if size <= max_batch_size])

# Add max_batch_size if not already included
if max_batch_size != batch_sizes[-1]:
batch_sizes.append(max_batch_size)

return batch_sizes
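
A quick illustration of what this helper returns, worked out from the logic above (calling the private helper directly here is for illustration only):

# enable_padding=True, max_batch_size=128:
#   [1, 2, 4] + multiples of 8 up to 128; no extra powers of two; 128 is already last
sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(128, enable_padding=True)
# -> [1, 2, 4, 8, 16, 24, ..., 120, 128]

# enable_padding=True, max_batch_size=512:
#   the power-of-two fill contributes 256, and 512 is appended at the end
sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(512, enable_padding=True)
# -> [1, 2, 4, 8, ..., 120, 128, 256, 512]
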


class MoeConfig(BaseModel):
class MoeConfig(StrictBaseModel):
"""
Configuration for MoE.
"""
@@ -225,7 +266,7 @@ def to_mapping(self) -> Mapping:
auto_parallel=self.auto_parallel)


class CalibConfig(BaseModel):
class CalibConfig(StrictBaseModel):
"""
Calibration configuration.
"""
@@ -277,7 +318,7 @@ class _ModelFormatKind(Enum):
TLLM_ENGINE = 2


class DecodingBaseConfig(BaseModel):
class DecodingBaseConfig(StrictBaseModel):
max_draft_len: Optional[int] = None
speculative_model_dir: Optional[Union[str, Path]] = None

@@ -298,6 +339,7 @@ def from_dict(cls, data: dict):
config_class = config_classes.get(decoding_type)
if config_class is None:
raise ValueError(f"Invalid decoding type: {decoding_type}")
data.pop("decoding_type")

return config_class(**data)
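
The newly added pop is what keeps from_dict compatible with the stricter base model: the discriminator key has to be removed before the concrete config class is constructed, otherwise extra = "forbid" would reject it as an unknown field. A self-contained sketch of the dispatch-then-pop pattern (class and key names here are stand-ins, not the real TRT-LLM ones):

from pydantic import BaseModel

class _Strict(BaseModel):
    class Config:
        extra = "forbid"

class EagleConfig(_Strict):        # stand-in for one concrete decoding config
    max_draft_len: int = 0

_CONFIG_CLASSES = {"Eagle": EagleConfig}

def from_dict(data: dict):
    decoding_type = data.get("decoding_type")
    config_class = _CONFIG_CLASSES.get(decoding_type)
    if config_class is None:
        raise ValueError(f"Invalid decoding type: {decoding_type}")
    data.pop("decoding_type")      # without this, **data carries an unknown field
    return config_class(**data)

print(from_dict({"decoding_type": "Eagle", "max_draft_len": 4}))
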

@@ -496,7 +538,7 @@ def mirror_pybind_fields(pybind_class):
"""

def decorator(cls):
assert issubclass(cls, BaseModel)
assert issubclass(cls, StrictBaseModel)
# Get all non-private fields from the C++ class
cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
python_fields = set(cls.model_fields.keys())
@@ -597,7 +639,7 @@ def _to_pybind(self):


@PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
class DynamicBatchConfig(BaseModel, PybindMirror):
class DynamicBatchConfig(StrictBaseModel, PybindMirror):
"""Dynamic batch configuration.

Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -623,7 +665,7 @@ def _to_pybind(self):


@PybindMirror.mirror_pybind_fields(_SchedulerConfig)
class SchedulerConfig(BaseModel, PybindMirror):
class SchedulerConfig(StrictBaseModel, PybindMirror):
capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
description="The capacity scheduler policy to use")
@@ -645,7 +687,7 @@ def _to_pybind(self):


@PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
class PeftCacheConfig(BaseModel, PybindMirror):
class PeftCacheConfig(StrictBaseModel, PybindMirror):
"""
Configuration for the PEFT cache.
"""
@@ -773,7 +815,7 @@ def supports_backend(self, backend: str) -> bool:


@PybindMirror.mirror_pybind_fields(_KvCacheConfig)
class KvCacheConfig(BaseModel, PybindMirror):
class KvCacheConfig(StrictBaseModel, PybindMirror):
"""
Configuration for the KV cache.
"""
@@ -856,7 +898,7 @@ def _to_pybind(self):


@PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
"""
Configuration for extended runtime performance knobs.
"""
@@ -887,7 +929,7 @@ def _to_pybind(self):


@PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
class CacheTransceiverConfig(BaseModel, PybindMirror):
class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
"""
Configuration for the cache transceiver.
"""
@@ -947,7 +989,7 @@ def model_name(self) -> Union[str, Path]:
return self.model if isinstance(self.model, str) else None


class BaseLlmArgs(BaseModel):
class BaseLlmArgs(StrictBaseModel):
"""
Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
"""
@@ -1354,7 +1396,8 @@ def init_build_config(self):
"""
Creating a default BuildConfig if none is provided
"""
if self.build_config is None:
build_config = getattr(self, "build_config", None)
if build_config is None:
kwargs = {}
if self.max_batch_size:
kwargs["max_batch_size"] = self.max_batch_size
@@ -1367,10 +1410,10 @@ def init_build_config(self):
if self.max_input_len:
kwargs["max_input_len"] = self.max_input_len
self.build_config = BuildConfig(**kwargs)

assert isinstance(
self.build_config, BuildConfig
), f"build_config is not initialized: {self.build_config}"
else:
assert isinstance(
build_config,
BuildConfig), f"build_config is not initialized: {build_config}"
return self

@model_validator(mode="after")
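
One plausible reading of the switch to getattr (an assumption, not stated in the diff): the validator is shared by args classes that may not declare build_config as a field, and on a pydantic model without that field plain attribute access raises AttributeError, while getattr with a default does not. A tiny sketch of that difference:

from pydantic import BaseModel

class Args(BaseModel):
    class Config:
        extra = "forbid"

    max_batch_size: int = 8

args = Args()
print(getattr(args, "build_config", None))  # None: safe when the field is absent
# args.build_config                          # raises AttributeError
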
@@ -1813,7 +1856,7 @@ class LoadFormat(Enum):
DUMMY = 1


class TorchCompileConfig(BaseModel):
class TorchCompileConfig(StrictBaseModel):
"""
Configuration for torch.compile.
"""
@@ -2049,38 +2092,6 @@ def validate_checkpoint_format(self):

return self

@staticmethod
def _generate_cuda_graph_batch_sizes(max_batch_size: int,
enable_padding: bool) -> List[int]:
"""Generate a list of batch sizes for CUDA graphs.

Args:
max_batch_size: Maximum batch size to generate up to
enable_padding: Whether padding is enabled, which affects the batch size distribution

Returns:
List of batch sizes to create CUDA graphs for
"""
if enable_padding:
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
else:
batch_sizes = list(range(1, 32)) + [32, 64, 128]

# Add powers of 2 up to max_batch_size
batch_sizes += [
2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
]

# Filter and sort batch sizes
batch_sizes = sorted(
[size for size in batch_sizes if size <= max_batch_size])

# Add max_batch_size if not already included
if max_batch_size != batch_sizes[-1]:
batch_sizes.append(max_batch_size)

return batch_sizes

@model_validator(mode="after")
def validate_load_balancer(self) -> 'TorchLlmArgs':
from .._torch import MoeLoadBalancerConfig
@@ -2117,7 +2128,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
if config.batch_sizes:
config.batch_sizes = sorted(config.batch_sizes)
if config.max_batch_size != 0:
if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
config.max_batch_size, config.enable_padding):
raise ValueError(
"Please don't set both cuda_graph_config.batch_sizes "
@@ -2129,7 +2140,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
config.max_batch_size = max(config.batch_sizes)
else:
max_batch_size = config.max_batch_size or 128
generated_sizes = self._generate_cuda_graph_batch_sizes(
generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
max_batch_size, config.enable_padding)
config.batch_sizes = generated_sizes
config.max_batch_size = max_batch_size
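
Taken together, a hedged sketch of how cuda_graph_config behaves after this change (field names come from this file; the TorchLlmArgs wiring around it is assumed):

# Only max_batch_size set: the validator fills batch_sizes from
# CudaGraphConfig._generate_cuda_graph_batch_sizes(64, enable_padding=True).
cuda_graph_config = CudaGraphConfig(max_batch_size=64, enable_padding=True)

# Only batch_sizes set: the validator derives max_batch_size = max(batch_sizes).
cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4, 8])

# Both set: accepted only when batch_sizes equals the list the generator would
# produce for that max_batch_size; any other combination raises ValueError.
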