
Commit 840b788

committed: init
Signed-off-by: Superjomn <[email protected]>
1 parent 470544c · commit 840b788

File tree: 3 files changed (+290, −54 lines)


tensorrt_llm/llmapi/llm_args.py

Lines changed: 63 additions & 52 deletions
@@ -61,7 +61,16 @@
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 
 
-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids arbitrary fields.
+    """
+
+    class Config:
+        extra = "forbid"  # globally forbid arbitrary fields
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -88,8 +97,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v
 
+    @staticmethod
+    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
+                                         enable_padding: bool) -> List[int]:
+        """Generate a list of batch sizes for CUDA graphs.
+
+        Args:
+            max_batch_size: Maximum batch size to generate up to
+            enable_padding: Whether padding is enabled, which affects the batch size distribution
+
+        Returns:
+            List of batch sizes to create CUDA graphs for
+        """
+        if enable_padding:
+            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        else:
+            batch_sizes = list(range(1, 32)) + [32, 64, 128]
+
+        # Add powers of 2 up to max_batch_size
+        batch_sizes += [
+            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
+        ]
+
+        # Filter and sort batch sizes
+        batch_sizes = sorted(
+            [size for size in batch_sizes if size <= max_batch_size])
+
+        # Add max_batch_size if not already included
+        if max_batch_size != batch_sizes[-1]:
+            batch_sizes.append(max_batch_size)
+
+        return batch_sizes
+
 
-class MoeConfig(BaseModel):
+class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
     """
@@ -194,7 +235,7 @@ def to_mapping(self) -> Mapping:
             auto_parallel=self.auto_parallel)
 
 
-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -246,7 +287,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2
 
 
-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
 
@@ -267,6 +308,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")
 
         return config_class(**data)
 
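The new data.pop("decoding_type") is the companion to the strict base class: the selector key would otherwise be forwarded to config_class(**data) and rejected as an extra field. A hypothetical call (the decoding type string is a placeholder, not taken from the commit):

    cfg = DecodingBaseConfig.from_dict({
        "decoding_type": "SomeType",   # placeholder key/value for illustration
        "max_draft_len": 4,
    })
    # After the pop, only declared fields (here max_draft_len) reach the
    # StrictBaseModel-derived constructor, so validation no longer fails.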
@@ -465,7 +507,7 @@ def mirror_pybind_fields(pybind_class):
     """
 
     def decorator(cls):
-        assert issubclass(cls, BaseModel)
+        assert issubclass(cls, StrictBaseModel)
         # Get all non-private fields from the C++ class
         cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
         python_fields = set(cls.model_fields.keys())
@@ -566,7 +608,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.
 
     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -592,7 +634,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -614,7 +656,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -742,7 +784,7 @@ def supports_backend(self, backend: str) -> bool:
 
 
 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -825,7 +867,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -916,7 +958,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None
 
 
-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1299,7 +1341,8 @@ def init_build_config(self):
         """
        Creating a default BuildConfig if none is provided
         """
-        if self.build_config is None:
+        build_config = getattr(self, "build_config", None)
+        if build_config is None:
             kwargs = {}
             if self.max_batch_size:
                 kwargs["max_batch_size"] = self.max_batch_size
@@ -1312,10 +1355,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self
 
     @model_validator(mode="after")
@@ -1752,7 +1795,7 @@ class LoadFormat(Enum):
     DUMMY = 1
 
 
-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -1966,38 +2009,6 @@ def validate_checkpoint_format(self):
 
         return self
 
-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2034,7 +2045,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2046,7 +2057,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
         else:
             max_batch_size = config.max_batch_size or 128
-            generated_sizes = self._generate_cuda_graph_batch_sizes(
+            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size
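Rough trace of the fallback branch above when batch_sizes is unset (assuming max_batch_size is left at its default of 0; values follow from the generator relocated to CudaGraphConfig):

    max_batch_size = config.max_batch_size or 128        # 0 -> 128
    config.batch_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
        128, config.enable_padding)                       # sizes up to and including 128
    config.max_batch_size = 128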

tests/unittest/_torch/test_beam_search.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def llm_cuda_graph(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        cuda_graph_config=CudaGraphConfig(enabled=True),
+        cuda_graph_config=CudaGraphConfig(),
     )
 
 
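The test update follows from the stricter validation: CudaGraphConfig evidently declares no enabled field, so under StrictBaseModel the old keyword would now be rejected. A short sketch, mirroring the earlier example:

    from pydantic import ValidationError

    try:
        CudaGraphConfig(enabled=True)   # the old test argument
    except ValidationError:
        pass                            # now invalid; the test constructs CudaGraphConfig() instead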