 import json
 import math
 import os
+import types
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum, EnumMeta
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
-                    TypeAlias, Union)
+                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)

 import torch
 import yaml
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import

+TypeBaseModel = TypeVar("T", bound=BaseModel)
+

 def Field(default: Any = ...,
           *,
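
The `TypeBaseModel` TypeVar added above lets a classmethod (such as `from_pybind` below) declare that it returns an instance of the exact subclass it was called on, rather than a bare `BaseModel`. A minimal standalone sketch of the pattern; the `build` helper and `ExampleConfig` are hypothetical, not part of this change:

```python
from typing import Type, TypeVar

from pydantic import BaseModel

# Same pattern as TypeBaseModel above: a TypeVar bound to BaseModel.
T = TypeVar("T", bound=BaseModel)


def build(cls: Type[T], **kwargs) -> T:
    """Hypothetical helper: returns an instance of the exact subclass `cls`."""
    return cls(**kwargs)


class ExampleConfig(BaseModel):
    size: int = 0


cfg = build(ExampleConfig, size=4)  # type checkers infer cfg: ExampleConfig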
@@ -597,6 +600,62 @@ def pybind_equals(obj0, obj1):
                 return False
         return True

+    @classmethod
+    def from_pybind(cls: Type[TypeBaseModel],
+                    pybind_instance: "PybindMirror") -> TypeBaseModel:
+        """Construct an instance of the given class from the fields in the given
+        pybind class instance.
+
+        Args:
+            cls: Type of the class to construct, must be a subclass of pydantic
+                BaseModel
+            pybind_instance: Instance of the pybind class to construct from its
+                fields
+
+        Notes:
+            When a field value is None in the pybind class, but it's not
+            optional and has a default value in the BaseModel class, it gets
+            the default value defined in the BaseModel class.
+
+        Returns:
+            Instance of the given class, populated with the fields of the given
+            pybind instance
+        """  # noqa: D205
+        assert issubclass(cls, BaseModel)
+
+        # Some of the fields are optional in the C++ class but not in the
+        # python class, where they have a default value, so copy the value
+        # from the C++ instance only if it has one; otherwise the default
+        # value defined in the python class is used.
+        def _is_optional_type(annotation: Any) -> bool:
+            """Returns True if a type annotation represents an Optional type
+            (Optional[X]) or a Union type that includes None (Union[X, Y, None]
+            or X | Y | None).
+            """  # noqa: D205
+            origin = get_origin(annotation)
+            args = get_args(annotation)
+
+            # Union is for Optional[X]
+            # UnionType is for the new | operator in Python 3.10+
+            return (origin is Union
+                    or origin is types.UnionType) and type(None) in args
+
+        fields_non_optional_with_default_value_in_basemodel = {
+            field_name
+            for field_name, field_info in cls.model_fields.items()
+            if not (_is_optional_type(field_info.annotation)
+                    and field_info.is_required())
+        }
+
+        kwargs = {}
+        cpp_fields = PybindMirror.get_pybind_variable_fields(
+            type(pybind_instance))
+        for field_name in cpp_fields:
+            field_value = getattr(pybind_instance, field_name)
+            if field_value is not None or field_name not in fields_non_optional_with_default_value_in_basemodel:
+                kwargs[field_name] = field_value
+        return cls(**kwargs)
+

 class PybindMirrorMeta(type(PybindMirror)):
     pass
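
The heart of `from_pybind` is the optionality check: every field is copied from the pybind instance unless its value is `None` and the Python field is non-optional with a default, in which case the pydantic default wins. A standalone sketch of the `_is_optional_type` predicate as written above (requires Python 3.10+ for the `|` union syntax):

```python
import types
from typing import Optional, Union, get_args, get_origin


def _is_optional_type(annotation) -> bool:
    # typing.Union covers Optional[X] and Union[X, Y, None];
    # types.UnionType covers the X | None syntax from Python 3.10+.
    origin = get_origin(annotation)
    args = get_args(annotation)
    return (origin is Union or origin is types.UnionType) and type(None) in args


assert _is_optional_type(Optional[int])
assert _is_optional_type(Union[int, str, None])
assert _is_optional_type(int | None)
assert not _is_optional_type(int)              # not a union at all
assert not _is_optional_type(Union[int, str])  # union without None
```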
@@ -694,11 +753,12 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
         default=0,
         description=
         "number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"
-    )
+        ", affects host cache size and overrides value of host_cache_size")
     num_device_module_layer: int = Field(
         default=0,
         description=
-        "number of max sized 1-layer 1-module sets of weights that can be stored in host cache"
+        "number of max sized 1-layer 1-module sets of weights that can be stored in device cache"
+        ", affects device cache size and overrides value of device_cache_percent"
     )
     optimal_adapter_size: int = Field(
         default=
@@ -725,15 +785,17 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
     max_pages_per_block_device: int = Field(
         default=8,
         description="Number of cache pages per allocation block (device)")
-    device_cache_percent: Optional[float] = Field(
-        default=None,
-        description="percent of memory after engine load to use for cache")
-    host_cache_size: Optional[int] = Field(
-        default=None, description="size in bytes to use for host cache")
+    device_cache_percent: float = Field(
+        default=0.02,
+        description=
+        "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"
+    )
+    host_cache_size: int = Field(
+        default=1024**3, description="size in bytes to use for host cache")
     lora_prefetch_dir: Optional[str] = Field(
         default=None,
         description=
-        "folder to store the LoRA weights we hope to load during engine initialization"
+        "folder to store the LoRA weights we hope to load during engine initialization, currently not supported"
     )

     def _to_pybind(self):
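
`device_cache_percent` and `host_cache_size` are no longer optional: the device cache now defaults to 2% of free device memory after engine load, and the host cache to 1 GiB. A toy sketch of the new field shapes (this model stands in for the real `PeftCacheConfig`, which carries many more fields):

```python
from pydantic import BaseModel, Field


class PeftCacheConfigSketch(BaseModel):
    # Fraction (0 to 1) of free device memory used for the device cache.
    device_cache_percent: float = Field(default=0.02)
    # Host cache size in bytes; 1024**3 bytes == 1 GiB.
    host_cache_size: int = Field(default=1024**3)


cfg = PeftCacheConfigSketch()
assert cfg.device_cache_percent == 0.02
assert cfg.host_cache_size == 1024**3
```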
@@ -1083,27 +1145,6 @@ class BaseLlmArgs(StrictBaseModel):
     # LoRA arguments
     enable_lora: bool = Field(default=False, description="Enable LoRA.")

-    max_lora_rank: Optional[int] = Field(
-        default=None,
-        description="The maximum LoRA rank.",
-        deprecated="Use lora_config.max_lora_rank instead.",
-        status="deprecated",
-    )
-
-    max_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA.",
-        deprecated="Use lora_config.max_loras instead.",
-        status="deprecated",
-    )
-
-    max_cpu_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA on CPU.",
-        deprecated="Use lora_config.max_cpu_loras instead.",
-        status="deprecated",
-    )
-
     lora_config: Optional[LoraConfig] = Field(
         default=None, description="LoRA configuration for the model.")
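
Callers that passed the removed arguments to `LLM` directly should move them onto `lora_config`. A hedged migration sketch; the import paths and the `LLM` constructor signature are assumptions based on the LLM API examples, not part of this diff:

```python
from tensorrt_llm import LLM
from tensorrt_llm.lora_manager import LoraConfig

# Before (deprecated arguments, now removed):
#   llm = LLM(model="/path/to/model", enable_lora=True,
#             max_lora_rank=8, max_loras=4, max_cpu_loras=4)
# After (the same knobs live on LoraConfig):
lora_config = LoraConfig(max_lora_rank=8, max_loras=4, max_cpu_loras=4)
llm = LLM(model="/path/to/model", enable_lora=True, lora_config=lora_config)
```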
@@ -1494,10 +1535,10 @@ def validate_build_config_remaining(self):
         if self.parallel_config._world_size == 1 and self.build_config:
             self.build_config.plugin_config.nccl_plugin = None

-        if self.enable_lora and self.lora_config is None and self.backend != 'pytorch':
+        if self.enable_lora and self.backend != 'pytorch':
             self.build_config.plugin_config.lora_plugin = 'auto'
-            if self.max_lora_rank is not None:
-                self.build_config.lora_config.max_lora_rank = self.max_lora_rank
+            if self.lora_config is not None:
+                self.build_config.lora_config.max_lora_rank = self.lora_config.max_lora_rank

         if hasattr(self,
                    'enable_prompt_adapter') and self.enable_prompt_adapter:
@@ -1601,16 +1642,6 @@ def validate_speculative_config(self):
     @model_validator(mode="after")
     def validate_lora_config_consistency(self):
         if self.lora_config:
-            if self.max_lora_rank is not None:
-                logger.warning(
-                    "max_lora_rank is ignored when lora_config is provided.")
-            if self.max_loras != self.lora_config.max_loras:
-                logger.warning(
-                    "max_loras is ignored when lora_config is provided.")
-            if self.max_cpu_loras != self.lora_config.max_cpu_loras:
-                logger.warning(
-                    "max_cpu_loras is ignored when lora_config is provided.")
-
             if len(self.lora_config.lora_dir) == 0:
                 # TODO [TRTLLM-5173]
                 logger.warning(
@@ -1637,6 +1668,14 @@ def validate_lora_config_consistency(self):
                 default_trtllm_modules_to_hf_modules.keys())
         return self

+    @model_validator(mode="after")
+    def validate_peft_cache_config(self):
+        if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None:
+            raise ValueError(
+                f"lora_prefetch_dir was set to '{self.peft_cache_config.lora_prefetch_dir}' "
+                "while LoRA prefetch is not supported")
+        return self
+

     def _update_plugin_config(self, key: str, value: Any):
         setattr(self.build_config.plugin_config, key, value)
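
The new validator rejects any configuration that sets `lora_prefetch_dir`, since LoRA prefetch is not yet supported. A minimal reproduction of the guard with a toy pydantic model (not the real `BaseLlmArgs`):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class ArgsSketch(BaseModel):
    lora_prefetch_dir: Optional[str] = None

    @model_validator(mode="after")
    def validate_peft_cache_config(self):
        if self.lora_prefetch_dir is not None:
            raise ValueError(
                f"lora_prefetch_dir was set to '{self.lora_prefetch_dir}' "
                "while LoRA prefetch is not supported")
        return self


ArgsSketch()  # fine: prefetch dir unset
try:
    ArgsSketch(lora_prefetch_dir="/tmp/lora")
except ValueError as e:  # pydantic raises ValidationError, a ValueError subclass
    print(e)
```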