
Commit 1ee7a08

[5830][feat] Improve LoRA cache memory control (#6220)
Signed-off-by: Amit Zuker <[email protected]>
1 parent 83e9765 commit 1ee7a08

13 files changed (+332, −90 lines)


examples/llm-api/llm_multilora.py

Lines changed: 5 additions & 6 deletions
@@ -5,7 +5,6 @@
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.llmapi import BuildConfig
 from tensorrt_llm.lora_manager import LoraConfig
 
 
@@ -19,12 +18,12 @@ def main():
 
     # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
     # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
-    build_config = BuildConfig()
-    build_config.lora_config = LoraConfig(lora_dir=[lora_dir1])
+    lora_config = LoraConfig(lora_dir=[lora_dir1],
+                             max_lora_rank=64,
+                             max_loras=3,
+                             max_cpu_loras=3)
     llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              enable_lora=True,
-              max_lora_rank=64,
-              build_config=build_config)
+              lora_config=lora_config)
 
     # Sample prompts
     prompts = [
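
For orientation, here is a minimal sketch of how the updated example is driven end to end with the new lora_config path. It is illustrative only and not part of this commit; the adapter directories, adapter names, and prompts are placeholders.

from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig

# Hypothetical local adapter checkpoints; any HF-format LoRA dirs would do.
lora_dir1 = "/path/to/lora-adapter-1"
lora_dir2 = "/path/to/lora-adapter-2"

# Cache capacity is now controlled here: up to 3 adapters resident in the GPU
# cache and 3 in the CPU cache, instead of via the removed max_loras /
# max_cpu_loras LLM arguments.
lora_config = LoraConfig(lora_dir=[lora_dir1],
                         max_lora_rank=64,
                         max_loras=3,
                         max_cpu_loras=3)
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", lora_config=lora_config)

# Each request may select a different adapter by name, ID, and path.
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    lora_request=[
        LoRARequest("adapter-1", 1, lora_dir1),
        LoRARequest("adapter-2", 2, lora_dir2),
    ])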

examples/llm-api/quickstart_multimodal.py

Lines changed: 3 additions & 0 deletions
@@ -148,6 +148,9 @@ def main():
         models_module = importlib.import_module('tensorrt_llm._torch.models')
         model_class = getattr(models_module, args.auto_model_name)
         lora_config = model_class.lora_config(args.model_dir)
+        # For stability - explicitly set the LoRA GPU cache & CPU cache to have space for 2 adapters
+        lora_config.max_loras = 2
+        lora_config.max_cpu_loras = 2
 
     llm, sampling_params = setup_llm(args, lora_config=lora_config)
 

tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 4 additions & 4 deletions
@@ -271,16 +271,16 @@ def lora_request(num_requests: int, modality: str, base_model_dir: str):
     if modality == "image" or modality == "image_audio":
         lora_request = [
             LoRARequest(
-                lora_name=f"vision-lora-{i}",
-                lora_int_id=i,
+                lora_name="vision-lora",
+                lora_int_id=0,
                 lora_path=f"{base_model_dir}/vision-lora",
             ) for i in range(num_requests)
         ]
     elif modality == "audio":
         lora_request = [
             LoRARequest(
-                lora_name=f"speech-lora-{i}",
-                lora_int_id=i,
+                lora_name="speech-lora",
+                lora_int_id=1,
                 lora_path=f"{base_model_dir}/speech-lora",
             ) for i in range(num_requests)
         ]

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 11 additions & 5 deletions
@@ -11,6 +11,7 @@
 from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_manager import (LoraConfig,
                                        get_default_trtllm_modules_to_hf_modules,
@@ -480,12 +481,17 @@ def create_py_executor_instance(
         num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \
             len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
 
-        executor_config.peft_cache_config = trtllm.PeftCacheConfig(
-            num_device_module_layer=max_lora_rank * num_lora_modules *
-            lora_config.max_loras,
-            num_host_module_layer=max_lora_rank * num_lora_modules *
-            lora_config.max_cpu_loras,
+        peft_cache_config_model = PeftCacheConfig.from_pybind(
+            executor_config.peft_cache_config
+        ) if executor_config.peft_cache_config is not None else PeftCacheConfig(
         )
+        if lora_config.max_loras is not None:
+            peft_cache_config_model.num_device_module_layer = \
+                max_lora_rank * num_lora_modules * lora_config.max_loras
+        if lora_config.max_cpu_loras is not None:
+            peft_cache_config_model.num_host_module_layer = \
+                max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
+        executor_config.peft_cache_config = peft_cache_config_model._to_pybind()
 
         from tensorrt_llm.bindings import WorldConfig
         world_config = WorldConfig(
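
The derived cache capacities follow directly from the code above: each adapter occupies at most max_lora_rank × num_lora_modules "1-layer 1-module" weight sets, where num_lora_modules counts the target modules (plus any missing QKV modules) per hidden layer, and the device and host caches are sized to hold max_loras and max_cpu_loras such adapters respectively. A rough illustrative calculation follows; the layer and module counts are made up, not taken from this commit.

# Illustrative numbers only - substitute the real model's values.
max_lora_rank = 64
num_hidden_layers = 22
num_target_modules = 3          # e.g. attn_q, attn_k, attn_v
num_lora_modules = num_hidden_layers * num_target_modules   # 66
max_loras = 3                   # adapters resident in the GPU cache

# Capacity requested from the PEFT cache, in 1-layer 1-module weight sets:
num_device_module_layer = max_lora_rank * num_lora_modules * max_loras  # 12672

# If lora_config.max_loras / max_cpu_loras are left as None, these fields are
# not overridden and the PeftCacheConfig defaults (device_cache_percent,
# host_cache_size) govern the cache sizes instead.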

tensorrt_llm/llmapi/llm.py

Lines changed: 25 additions & 9 deletions
@@ -31,8 +31,8 @@
 from ..logger import logger
 from ..sampling_params import SamplingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
-                       TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
+                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -815,19 +815,35 @@ def _build_model(self):
         if self.args.peft_cache_config is not None:
             self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.peft_cache_config)
-        elif self.args.build_config.plugin_config.lora_plugin:
+
+        lora_config = None
+        if self.args.build_config.plugin_config.lora_plugin:
             engine_config = EngineConfig.from_json_file(self._engine_dir /
                                                         "config.json")
             lora_config = engine_config.build_config.lora_config
+            if self.args.lora_config is not None:
+                logger.info(
+                    "Overriding lora_config from engine with lora_config from LLM args"
+                )
+                lora_config = self.args.lora_config
+
             max_lora_rank = lora_config.max_lora_rank
             num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
                 len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
-            self._executor_config.peft_cache_config = tllm.PeftCacheConfig(
-                num_device_module_layer=max_lora_rank * num_lora_modules *
-                self.args.max_loras,
-                num_host_module_layer=max_lora_rank * num_lora_modules *
-                self.args.max_cpu_loras,
+
+            peft_cache_config_model = PeftCacheConfig.from_pybind(
+                self._executor_config.peft_cache_config
+            ) if self._executor_config.peft_cache_config is not None else PeftCacheConfig(
+            )
+            if lora_config.max_loras is not None:
+                peft_cache_config_model.num_device_module_layer = \
+                    max_lora_rank * num_lora_modules * lora_config.max_loras
+            if lora_config.max_cpu_loras is not None:
+                peft_cache_config_model.num_host_module_layer = \
+                    max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
+            self._executor_config.peft_cache_config = peft_cache_config_model._to_pybind(
             )
+
         if self.args.decoding_config is not None:
             self._executor_config.decoding_config = self.args.decoding_config
         if self.args.guided_decoding_backend == 'xgrammar':
@@ -868,7 +884,7 @@ def _build_model(self):
                 postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
             ),
             is_llm_executor=True,
-            lora_config=self.args.lora_config)
+            lora_config=lora_config)
 
 
 @append_docstring(TORCH_LLM_DOCSTRING)
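
The net effect of the rewritten _build_model logic above: an explicitly passed peft_cache_config supplies the baseline cache settings, and lora_config.max_loras / max_cpu_loras, when set, override only the two module-layer capacity fields on top of it. A hedged sketch of that interplay; the model name and sizes are placeholders, not taken from the commit.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.lora_manager import LoraConfig

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    # Baseline cache settings: 2 GiB of host memory for the CPU cache.
    peft_cache_config=PeftCacheConfig(host_cache_size=2 * 1024**3),
    # max_loras pins the GPU cache to 4 adapters' worth of module-layer slots;
    # max_cpu_loras is left unset, so the host cache keeps the 2 GiB size above.
    lora_config=LoraConfig(lora_dir=["/path/to/lora-adapter"],
                           max_lora_rank=64,
                           max_loras=4))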

tensorrt_llm/llmapi/llm_args.py

Lines changed: 82 additions & 43 deletions
@@ -3,12 +3,13 @@
 import json
 import math
 import os
+import types
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum, EnumMeta
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
-                    TypeAlias, Union)
+                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)
 
 import torch
 import yaml
@@ -60,6 +61,8 @@
 
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 
+TypeBaseModel = TypeVar("T", bound=BaseModel)
+
 
 def Field(default: Any = ...,
           *,
@@ -597,6 +600,62 @@ def pybind_equals(obj0, obj1):
                 return False
         return True
 
+    @classmethod
+    def from_pybind(cls: Type[TypeBaseModel],
+                    pybind_instance: "PybindMirror") -> TypeBaseModel:
+        """Construct an instance of the given class from the fields in the given
+        pybind class instance.
+
+        Args:
+            cls: Type of the class to construct, must be a subclass of pydantic
+                BaseModel
+            pybind_instance: Instance of the pybind class to construct from its
+                fields
+
+        Notes:
+            When a field value is None in the pybind class, but it's not
+            optional and has a default value in the BaseModel class, it would
+            get the default value defined in the BaseModel class.
+
+        Returns:
+            Instance of the given class, populated with the fields of the given
+            pybind instance
+        """  # noqa: D205
+        assert issubclass(cls, BaseModel)
+
+        # Some of the fields are optional in the C++ class but in python they aren't
+        # optional and have a default value, so copy the value from C++ instance
+        # only if it has a value, so otherwise the default value defined in the
+        # python class would be set.
+        def _is_optional_type(annotation: Any) -> bool:
+            """Returns True if a type annotation represents an Optional type
+            (Optional[X]) or a Union type that includes None (Union[X, Y, None]
+            or X | Y | None).
+            """  # noqa: D205
+            origin = get_origin(annotation)
+            args = get_args(annotation)
+
+            # Union is for Optional[x]
+            # UnionType is for the new | operation in Python 3.10+
+            return (origin is Union
+                    or origin is types.UnionType) and type(None) in args
+
+        fields_non_optional_with_default_value_in_basemodel = {
+            field_name
+            for field_name, field_info in cls.model_fields.items()
+            if not (_is_optional_type(field_info.annotation)
+                    and field_info.is_required())
+        }
+
+        kwargs = {}
+        cpp_fields = PybindMirror.get_pybind_variable_fields(
+            type(pybind_instance))
+        for field_name in cpp_fields:
+            field_value = getattr(pybind_instance, field_name)
+            if field_value is not None or field_name not in fields_non_optional_with_default_value_in_basemodel:
+                kwargs[field_name] = field_value
+        return cls(**kwargs)
+
 
 class PybindMirrorMeta(type(PybindMirror)):
     pass
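
A small sketch of how this helper is used elsewhere in the commit: a pybind PeftCacheConfig is lifted into the pydantic model, mutated, and converted back. The round-trip below is illustrative, not additional code from the diff.

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

# Round-trip: pydantic model -> pybind object -> back to pydantic model.
original = PeftCacheConfig(num_device_module_layer=12672,
                           host_cache_size=2 * 1024**3)
pybind_cfg = original._to_pybind()

restored = PeftCacheConfig.from_pybind(pybind_cfg)
assert restored.num_device_module_layer == 12672
# Fields that are None on the pybind side but non-optional with a default in
# the pydantic model keep the pydantic default after from_pybind().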
@@ -694,11 +753,12 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
         default=0,
         description=
         "number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"
-    )
+        ", affects host cache size and overrides value of host_cache_size")
     num_device_module_layer: int = Field(
         default=0,
         description=
-        "number of max sized 1-layer 1-module sets of weights that can be stored in host cache"
+        "number of max sized 1-layer 1-module sets of weights that can be stored in device cache"
+        ", affects device cache size and overrides value of device_cache_percent"
     )
     optimal_adapter_size: int = Field(
         default=
@@ -725,15 +785,17 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
     max_pages_per_block_device: int = Field(
         default=8,
         description="Number of cache pages per allocation block (device)")
-    device_cache_percent: Optional[float] = Field(
-        default=None,
-        description="percent of memory after engine load to use for cache")
-    host_cache_size: Optional[int] = Field(
-        default=None, description="size in bytes to use for host cache")
+    device_cache_percent: float = Field(
+        default=0.02,
+        description=
+        "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"
+    )
+    host_cache_size: int = Field(
+        default=1024**3, description="size in bytes to use for host cache")
     lora_prefetch_dir: Optional[str] = Field(
         default=None,
         description=
-        "folder to store the LoRA weights we hope to load during engine initialization"
+        "folder to store the LoRA weights we hope to load during engine initialization, currently not supported"
     )
 
     def _to_pybind(self):
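
With these changes the caches get concrete defaults even when nothing is configured. A brief sketch of what the new defaults mean in practice, using the values from the field definitions above:

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

cfg = PeftCacheConfig()
# By default the device cache takes 2% of the free GPU memory left after
# engine load, and the host cache takes 1 GiB.
assert cfg.device_cache_percent == 0.02
assert cfg.host_cache_size == 1024**3

# Setting the module-layer counts (directly, or indirectly via
# lora_config.max_loras / max_cpu_loras) overrides those size-based defaults.
cfg.num_device_module_layer = 4096
cfg.num_host_module_layer = 8192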
@@ -1083,27 +1145,6 @@ class BaseLlmArgs(StrictBaseModel):
     # LoRA arguments
     enable_lora: bool = Field(default=False, description="Enable LoRA.")
 
-    max_lora_rank: Optional[int] = Field(
-        default=None,
-        description="The maximum LoRA rank.",
-        deprecated="Use lora_config.max_lora_rank instead.",
-        status="deprecated",
-    )
-
-    max_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA.",
-        deprecated="Use lora_config.max_loras instead.",
-        status="deprecated",
-    )
-
-    max_cpu_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA on CPU.",
-        deprecated="Use lora_config.max_cpu_loras instead.",
-        status="deprecated",
-    )
-
     lora_config: Optional[LoraConfig] = Field(
         default=None, description="LoRA configuration for the model.")
 
@@ -1494,10 +1535,10 @@ def validate_build_config_remaining(self):
         if self.parallel_config._world_size == 1 and self.build_config:
             self.build_config.plugin_config.nccl_plugin = None
 
-        if self.enable_lora and self.lora_config is None and self.backend != 'pytorch':
+        if self.enable_lora and self.backend != 'pytorch':
             self.build_config.plugin_config.lora_plugin = 'auto'
-            if self.max_lora_rank is not None:
-                self.build_config.lora_config.max_lora_rank = self.max_lora_rank
+            if self.lora_config is not None:
+                self.build_config.lora_config.max_lora_rank = self.lora_config.max_lora_rank
 
         if hasattr(self,
                    'enable_prompt_adapter') and self.enable_prompt_adapter:
@@ -1601,16 +1642,6 @@ def validate_speculative_config(self):
     @model_validator(mode="after")
     def validate_lora_config_consistency(self):
         if self.lora_config:
-            if self.max_lora_rank is not None:
-                logger.warning(
-                    "max_lora_rank is ignored when lora_config is provided.")
-            if self.max_loras != self.lora_config.max_loras:
-                logger.warning(
-                    "max_loras is ignored when lora_config is provided.")
-            if self.max_cpu_loras != self.lora_config.max_cpu_loras:
-                logger.warning(
-                    "max_cpu_loras is ignored when lora_config is provided.")
-
             if len(self.lora_config.lora_dir) == 0:
                 # TODO [TRTLLM-5173]
                 logger.warning(
@@ -1637,6 +1668,14 @@
                     default_trtllm_modules_to_hf_modules.keys())
         return self
 
+    @model_validator(mode="after")
+    def validate_peft_cache_config(self):
+        if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None:
+            raise ValueError(
+                f"lora_prefetch_dir was set to '{self.peft_cache_config.lora_prefetch_dir}' "
+                "while LoRA prefetch is not supported")
+        return self
+
     def _update_plugin_config(self, key: str, value: Any):
         setattr(self.build_config.plugin_config, key, value)
 
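
The new validate_peft_cache_config validator makes the unsupported prefetch path fail fast instead of being silently ignored. A hedged sketch of the observable behavior; the model and directory below are placeholders.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

# Expected to raise ValueError: lora_prefetch_dir is declared unsupported.
try:
    LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        peft_cache_config=PeftCacheConfig(lora_prefetch_dir="/tmp/loras"))
except ValueError as e:
    print(e)  # "lora_prefetch_dir was set to '/tmp/loras' while LoRA prefetch is not supported"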

tensorrt_llm/lora_manager.py

Lines changed: 2 additions & 2 deletions
@@ -203,8 +203,8 @@ class LoraConfig(DictConversion):
     max_lora_rank: int = 64
     lora_target_modules: List[str] = field(default_factory=list)
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int = 4
-    max_cpu_loras: int = 4
+    max_loras: int | None = None
+    max_cpu_loras: int | None = None
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
tests/unittest/llmapi/apps/_test_openai_lora.py

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ def temp_extra_llm_api_options_file():
         extra_llm_api_options_dict = {
             "lora_config": {
                 "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
-                "max_lora_rank": 8
+                "max_lora_rank": 8,
+                "max_loras": 4,
+                "max_cpu_loras": 4,
             }
         }
 

tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,9 @@ def temp_extra_llm_api_options_file():
         extra_llm_api_options_dict = {
             "lora_config": {
                 "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
-                "max_lora_rank": 8
+                "max_lora_rank": 8,
+                "max_loras": 4,
+                "max_cpu_loras": 4,
             }
         }
 
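
For context, both test fixtures serialize this dict to a temporary YAML file that is handed to the server as its extra LLM-API options. A hedged sketch of the equivalent standalone step; the output path is a placeholder and the exact server flag name is assumed from the tests' intent rather than quoted from this commit.

import yaml

extra_llm_api_options_dict = {
    "lora_config": {
        "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
        "max_lora_rank": 8,
        "max_loras": 4,
        "max_cpu_loras": 4,
    }
}

# The fixtures write this dict to a temporary YAML file and pass its path to
# the server process (trtllm-serve's extra-LLM-API-options mechanism).
with open("/tmp/extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_llm_api_options_dict, f)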
