31
31
create_input_processor_with_hash , prompt_inputs )
32
32
from ..logger import logger
33
33
from ..sampling_params import SamplingParams
34
- from .llm_args import (LLMARGS_EXPLICIT_DOCSTRING , PybindMirror , TorchLlmArgs ,
35
- TrtLlmArgs , _AutoDeployLlmArgs )
34
+ from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING ,
35
+ TRT_LLMARGS_EXPLICIT_DOCSTRING , PybindMirror ,
36
+ TorchLlmArgs , TrtLlmArgs , _AutoDeployLlmArgs )
36
37
from .llm_utils import (CachedModelLoader , KvCacheRetentionConfig ,
37
38
LlmBuildStats , ModelLoader , _ModelRuntimeContext )
38
39
from .mpi_session import MpiPoolSession , external_mpi_comm_available
@@ -83,23 +84,26 @@ def _repr_fields(self):
83
84
]
84
85
85
86
86
- LLM_DOCSTRING = LLMARGS_EXPLICIT_DOCSTRING + """
87
- kwargs (Any): Advanced arguments passed to `LlmArgs`.
87
+ TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
88
88
89
89
Attributes:
90
90
tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
91
91
workspace (pathlib.Path): The directory to store intermediate files.
92
92
llm_id (str): The unique ID of the LLM instance.
93
93
"""
94
94
95
+ TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
95
96
96
- @append_docstring (LLM_DOCSTRING )
97
- class LLM :
98
- """LLM class is the main class for running a LLM model.
99
-
100
- Parameters:
97
+ Attributes:
98
+ tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
101
99
"""
102
100
101
+
102
+ class BaseLLM :
103
+ """
104
+ The base class for all LLM classes.
105
+ """
106
+
103
107
def __init__ (self ,
104
108
model : Union [str , Path ],
105
109
tokenizer : Optional [Union [str , Path , TokenizerBase ,
@@ -186,6 +190,8 @@ def __init__(self,
186
190
if self ._on_trt_backend :
187
191
self ._workspace = tempfile .TemporaryDirectory (
188
192
suffix = "-llm-workspace" , dir = self .args .workspace )
193
+ else :
194
+ self ._workspace = None
189
195
190
196
self ._hf_model_dir : Optional [Path ] = None
191
197
@@ -202,10 +208,6 @@ def __init__(self,
202
208
exception_handler .register (self , 'shutdown' )
203
209
atexit .register (LLM ._shutdown_wrapper , weakref .ref (self ))
204
210
205
- @property
206
- def workspace (self ) -> Path :
207
- return Path (self ._workspace .name ) if self ._on_trt_backend else None
208
-
209
211
@property
210
212
def llm_id (self ) -> str :
211
213
if self ._llm_id is None :
@@ -584,7 +586,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
584
586
def _build_model (self ):
585
587
model_loader = CachedModelLoader (self .args ,
586
588
mpi_session = self .mpi_session ,
587
- workspace = self .workspace ,
589
+ workspace = self ._workspace ,
588
590
llm_build_stats = weakref .proxy (
589
591
self .llm_build_stats ))
590
592
self ._engine_dir , self ._hf_model_dir = model_loader ()
@@ -766,6 +768,66 @@ def tokenizer(self) -> Optional[TokenizerBase]:
766
768
def tokenizer (self , tokenizer : TokenizerBase ):
767
769
self ._tokenizer = tokenizer
768
770
771
+ def shutdown (self ) -> None :
772
+ if hasattr (self , "_executor" ) and self ._executor is not None :
773
+ self ._executor .shutdown ()
774
+ self ._executor = None
775
+
776
+ if hasattr (self , 'mpi_session' ) and self .mpi_session is not None :
777
+ self .mpi_session .shutdown ()
778
+ self .mpi_session = None
779
+
780
+ @staticmethod
781
+ def _shutdown_wrapper (self_ref ):
782
+ # Retrieve the instance if it still exists
783
+ instance = self_ref ()
784
+ if instance is not None :
785
+ instance .shutdown ()
786
+
787
+ def __enter__ (self ):
788
+ return self
789
+
790
+ def __exit__ (self , exc_type , exc_value , traceback ) -> bool :
791
+ del exc_value , traceback
792
+ self .shutdown ()
793
+ return False # propagate exceptions
794
+
795
+ def __getstate__ (self ):
796
+ raise RuntimeError ("LLM object can not be pickled." )
797
+
798
+ def __del__ (self ):
799
+ self .shutdown ()
800
+
801
+
802
+ @append_docstring (TRT_LLM_DOCSTRING )
803
+ class _TrtLLM (BaseLLM ):
804
+ """LLM class is the main class for running a LLM model using TensorRT-LLM backend.
805
+
806
+ Parameters:
807
+ """
808
+
809
+ def __init__ (self ,
810
+ model : Union [str , Path ],
811
+ tokenizer : Optional [Union [str , Path , TokenizerBase ,
812
+ PreTrainedTokenizerBase ]] = None ,
813
+ tokenizer_mode : Literal ['auto' , 'slow' ] = 'auto' ,
814
+ skip_tokenizer_init : bool = False ,
815
+ trust_remote_code : bool = False ,
816
+ tensor_parallel_size : int = 1 ,
817
+ dtype : str = "auto" ,
818
+ revision : Optional [str ] = None ,
819
+ tokenizer_revision : Optional [str ] = None ,
820
+ ** kwargs : Any ) -> None :
821
+ # TODO: deprecate backend in LLM kwargs
822
+
823
+ super ().__init__ (model , tokenizer , tokenizer_mode , skip_tokenizer_init ,
824
+ trust_remote_code , tensor_parallel_size , dtype ,
825
+ revision , tokenizer_revision , ** kwargs )
826
+
827
+ @property
828
+ def workspace (self ) -> Path :
829
+ return Path (self ._workspace .name ) if self ._on_trt_backend else None
830
+
769
831
def save (self , engine_dir : str ) -> None :
770
832
"""Save the built engine to the given path.
771
833
@@ -791,32 +853,71 @@ def save(self, engine_dir: str) -> None:
791
853
f"Copying { file } to { target_engine_dir / file .name } \n " )
792
854
shutil .copy (file , target_engine_dir / file .name )
793
855
794
- def shutdown (self ) -> None :
795
- if hasattr (self , "_executor" ) and self ._executor is not None :
796
- self ._executor .shutdown ()
797
- self ._executor = None
798
856
799
- if hasattr ( self , 'mpi_session' ) and self . mpi_session is not None :
800
- self . mpi_session . shutdown ()
801
- self . mpi_session = None
857
+ @ append_docstring ( TORCH_LLM_DOCSTRING )
858
+ class _TorchLLM ( BaseLLM ):
859
+ """LLM class is the main class for running a LLM model using PyTorch backend.
802
860
803
- @staticmethod
804
- def _shutdown_wrapper (self_ref ):
805
- # Retrieve the instance if it still exists
806
- instance = self_ref ()
807
- if instance is not None :
808
- instance .shutdown ()
861
+ Parameters:
862
+ """
809
863
810
- def __enter__ (self ):
811
- return self
864
+ def __init__ (self ,
865
+ model : Union [str , Path ],
866
+ tokenizer : Optional [Union [str , Path , TokenizerBase ,
867
+ PreTrainedTokenizerBase ]] = None ,
868
+ tokenizer_mode : Literal ['auto' , 'slow' ] = 'auto' ,
869
+ skip_tokenizer_init : bool = False ,
870
+ trust_remote_code : bool = False ,
871
+ tensor_parallel_size : int = 1 ,
872
+ dtype : str = "auto" ,
873
+ revision : Optional [str ] = None ,
874
+ tokenizer_revision : Optional [str ] = None ,
875
+ ** kwargs : Any ) -> None :
812
876
813
- def __exit__ (self , exc_type , exc_value , traceback ) -> bool :
814
- del exc_value , traceback
815
- self .shutdown ()
816
- return False # propagate exceptions
877
+ # TODO: deprecate backend in LLM kwargs
878
+ kwargs .pop ("backend" , None )
817
879
818
- def __getstate__ (self ):
819
- raise RuntimeError ("LLM object can not be pickled." )
880
+ super ().__init__ (model ,
881
+ tokenizer ,
882
+ tokenizer_mode ,
883
+ skip_tokenizer_init ,
884
+ trust_remote_code ,
885
+ tensor_parallel_size ,
886
+ dtype ,
887
+ revision ,
888
+ tokenizer_revision ,
889
+ backend = 'pytorch' ,
890
+ ** kwargs )
820
891
821
- def __del__ (self ):
822
- self .shutdown ()
892
+
893
+ class LLM (_TrtLLM ):
894
+
895
+ def __init__ (self ,
896
+ model : Union [str , Path ],
897
+ tokenizer : Optional [Union [str , Path , TokenizerBase ,
898
+ PreTrainedTokenizerBase ]] = None ,
899
+ tokenizer_mode : Literal ['auto' , 'slow' ] = 'auto' ,
900
+ skip_tokenizer_init : bool = False ,
901
+ trust_remote_code : bool = False ,
902
+ tensor_parallel_size : int = 1 ,
903
+ dtype : str = "auto" ,
904
+ revision : Optional [str ] = None ,
905
+ tokenizer_revision : Optional [str ] = None ,
906
+ ** kwargs : Any ) -> None :
907
+ super ().__init__ (model , tokenizer , tokenizer_mode , skip_tokenizer_init ,
908
+ trust_remote_code , tensor_parallel_size , dtype ,
909
+ revision , tokenizer_revision , ** kwargs )
910
+
911
+
912
+ _LLM_REPR = "TrtLLM"
913
+
914
+ # sphinx will ignore the LLM's docstring if it is not explicitly set
915
+ LLM .__doc__ = \
916
+ f"""LLM class is the main class for running a LLM model.
917
+
918
+ This class is an alias of { _LLM_REPR } . You can switch between the TensorRT backend
919
+ and the PyTorch backend by setting the TLLM_USE_TRT_ENGINE environment to 1 or 0.
920
+ The default backend is the TensorRT backend.
921
+
922
+ Parameters:
923
+ """ + TRT_LLM_DOCSTRING
0 commit comments