@@ -61,7 +61,16 @@
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 
 
-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids arbitrary fields.
+    """
+
+    class Config:
+        extra = "forbid"  # globally forbid arbitrary fields
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -88,8 +97,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v
 
+    @staticmethod
+    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
+                                         enable_padding: bool) -> List[int]:
+        """Generate a list of batch sizes for CUDA graphs.
+
+        Args:
+            max_batch_size: Maximum batch size to generate up to
+            enable_padding: Whether padding is enabled, which affects the batch size distribution
+
+        Returns:
+            List of batch sizes to create CUDA graphs for
+        """
+        if enable_padding:
+            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        else:
+            batch_sizes = list(range(1, 32)) + [32, 64, 128]
+
+        # Add powers of 2 up to max_batch_size
+        batch_sizes += [
+            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
+        ]
+
+        # Filter and sort batch sizes
+        batch_sizes = sorted(
+            [size for size in batch_sizes if size <= max_batch_size])
+
+        # Add max_batch_size if not already included
+        if max_batch_size != batch_sizes[-1]:
+            batch_sizes.append(max_batch_size)
+
+        return batch_sizes
+
 
-class MoeConfig(BaseModel):
+class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
     """
@@ -194,7 +235,7 @@ def to_mapping(self) -> Mapping:
                        auto_parallel=self.auto_parallel)
 
 
-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -246,7 +287,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2
 
 
-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
 
@@ -267,6 +308,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")
 
         return config_class(**data)
 
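The added `data.pop("decoding_type")` is a direct consequence of the `StrictBaseModel` switch: `decoding_type` is only a dispatch key, not a declared field on the concrete config classes, so leaving it in `data` would now trip the extra-fields check when the selected class is constructed. A sketch of that dispatch pattern (the concrete config class and its field below are illustrative, not the exact ones in this file):

```python
from typing import Optional
from pydantic import BaseModel


class StrictBaseModel(BaseModel):
    class Config:
        extra = "forbid"


class DemoDecodingConfig(StrictBaseModel):  # illustrative concrete config
    max_draft_len: Optional[int] = None


config_classes = {"demo": DemoDecodingConfig}


def from_dict(data: dict):
    decoding_type = data.get("decoding_type")
    config_class = config_classes.get(decoding_type)
    if config_class is None:
        raise ValueError(f"Invalid decoding type: {decoding_type}")
    # Without this pop, config_class(**data) would raise a ValidationError:
    # "decoding_type" is not a field, and extra fields are now forbidden.
    data.pop("decoding_type")
    return config_class(**data)


print(from_dict({"decoding_type": "demo", "max_draft_len": 4}))
```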
@@ -465,7 +507,7 @@ def mirror_pybind_fields(pybind_class):
         """
 
         def decorator(cls):
-            assert issubclass(cls, BaseModel)
+            assert issubclass(cls, StrictBaseModel)
             # Get all non-private fields from the C++ class
             cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
             python_fields = set(cls.model_fields.keys())
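For readers new to the mirroring machinery: the decorator compares the field names declared on the pydantic class against the variable fields exposed by the C++ pybind class, and the tightened assert now also forces every mirror onto the strict base so unknown fields are rejected uniformly. A simplified standalone sketch of the idea, assuming a plain set of field names in place of the real pybind introspection and reusing the `StrictBaseModel` sketch above (the field names are hypothetical):

```python
def mirror_fields(cpp_fields: set):
    """Simplified stand-in for PybindMirror.mirror_pybind_fields."""

    def decorator(cls):
        # Mirrors must inherit the strict extra-field policy.
        assert issubclass(cls, StrictBaseModel)
        python_fields = set(cls.model_fields.keys())
        missing = cpp_fields - python_fields
        assert not missing, f"mirror class lacks fields: {missing}"
        return cls

    return decorator


@mirror_fields({"max_tokens", "enable_reuse"})  # hypothetical field names
class DemoMirror(StrictBaseModel):
    max_tokens: int = 0
    enable_reuse: bool = True
```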
@@ -566,7 +608,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.
 
     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -592,7 +634,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -614,7 +656,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -742,7 +784,7 @@ def supports_backend(self, backend: str) -> bool:
 
 
 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -825,7 +867,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -916,7 +958,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None
 
 
-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1299,7 +1341,8 @@ def init_build_config(self):
         """
         Creating a default BuildConfig if none is provided
         """
-        if self.build_config is None:
+        build_config = getattr(self, "build_config", None)
+        if build_config is None:
             kwargs = {}
             if self.max_batch_size:
                 kwargs["max_batch_size"] = self.max_batch_size
@@ -1312,10 +1355,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self
 
     @model_validator(mode="after")
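The `getattr` guard is presumably what lets this validator run on argument classes that don't declare a `build_config` field at all: on a pydantic model, reading an undeclared attribute raises `AttributeError`, whereas `getattr(..., None)` falls through to the default-construction branch. A tiny illustration of the access pattern (both classes here are hypothetical):

```python
class ArgsWithBuild:
    build_config = "existing-config"  # stand-in for a BuildConfig instance


class ArgsWithoutBuild:
    pass  # no build_config attribute at all


for args in (ArgsWithBuild(), ArgsWithoutBuild()):
    # Safe on both: returns None instead of raising AttributeError.
    build_config = getattr(args, "build_config", None)
    print(build_config)  # "existing-config", then None
```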
@@ -1752,7 +1795,7 @@ class LoadFormat(Enum):
     DUMMY = 1
 
 
-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -1966,38 +2009,6 @@ def validate_checkpoint_format(self):
 
         return self
 
-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2034,7 +2045,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2046,7 +2057,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
         else:
             max_batch_size = config.max_batch_size or 128
-            generated_sizes = self._generate_cuda_graph_batch_sizes(
+            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size
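Taken together, the last two hunks define the contract between `cuda_graph_config.batch_sizes` and `cuda_graph_config.max_batch_size`. A sketch of the three cases, assuming the validator runs during `TorchLlmArgs` construction (the outcomes are read off the code above, not verbatim from tests):

```python
# Case 1: only batch_sizes given -> max_batch_size is inferred.
cfg = CudaGraphConfig(batch_sizes=[1, 2, 4])
# After validation: cfg.max_batch_size == 4 (max of the list).

# Case 2: only max_batch_size given -> batch_sizes is generated.
cfg = CudaGraphConfig(max_batch_size=64)
# After validation: cfg.batch_sizes comes from
# CudaGraphConfig._generate_cuda_graph_batch_sizes(64, cfg.enable_padding).

# Case 3: both given -> they must equal the generated list exactly,
# otherwise validate_cuda_graph_config raises ValueError
# ("Please don't set both cuda_graph_config.batch_sizes ...").
cfg = CudaGraphConfig(batch_sizes=[3, 5], max_batch_size=64)  # rejected later
```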