
Commit 724e495

chore: partition LLM class into TorchLLM and TrtLLM (#4900)
Signed-off-by: Superjomn <[email protected]>
1 parent e44f768 commit 724e495

File tree: 5 files changed, 162 additions and 74 deletions

  docs/source/helper.py
  tensorrt_llm/_torch/llm.py
  tensorrt_llm/llmapi/__init__.py
  tensorrt_llm/llmapi/llm.py
  tensorrt_llm/llmapi/llm_args.py
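
At a glance, the commit splits the old monolithic LLM into a small class hierarchy. The sketch below only outlines the resulting structure (bodies elided); the full definitions are in the tensorrt_llm/llmapi/llm.py diff below:

class BaseLLM:             # shared __init__, lifecycle, and shutdown logic
    ...

class _TrtLLM(BaseLLM):    # TensorRT-engine backend; keeps workspace and save()
    ...

class _TorchLLM(BaseLLM):  # PyTorch backend; forces backend='pytorch'
    ...

class LLM(_TrtLLM):        # public entry point; still defaults to TensorRT
    ...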

docs/source/helper.py

Lines changed: 12 additions & 5 deletions
@@ -149,11 +149,18 @@ def generate_llmapi():
     content = underline("API Reference", "-") + "\n\n"
     for cls_name in public_classes_names:
         cls_name = cls_name.strip()
-        content += (f".. autoclass:: tensorrt_llm.llmapi.{cls_name}\n"
-                    "    :members:\n"
-                    "    :undoc-members:\n"
-                    "    :special-members: __init__\n"
-                    "    :show-inheritance:\n")
+        options = [
+            "    :members:", "    :undoc-members:", "    :show-inheritance:"
+        ]
+
+        if cls_name != 'LLM':  # Conditionally add :special-members: __init__
+            options.append("    :special-members: __init__")
+
+        if cls_name in ['TrtLLM', 'TorchLLM', 'LLM']:
+            options.append("    :inherited-members:")
+
+        content += f".. autoclass:: tensorrt_llm.llmapi.{cls_name}\n"
+        content += "\n".join(options) + "\n\n"
 
     with open(doc_path, "w+") as f:
         f.write(content)
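
The per-class directive options are now assembled conditionally. This standalone sketch mirrors the new loop so its output can be inspected directly (the cls_name values are illustrative, not the real public_classes_names list):

# Reproduces the directive-building logic from generate_llmapi() above.
for cls_name in ['LLM', 'TorchLLM', 'SamplingParams']:
    options = [
        "    :members:", "    :undoc-members:", "    :show-inheritance:"
    ]
    if cls_name != 'LLM':  # the loop special-cases the top-level LLM alias
        options.append("    :special-members: __init__")
    if cls_name in ['TrtLLM', 'TorchLLM', 'LLM']:
        options.append("    :inherited-members:")
    print(f".. autoclass:: tensorrt_llm.llmapi.{cls_name}")
    print("\n".join(options) + "\n")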

tensorrt_llm/_torch/llm.py

Lines changed: 2 additions & 28 deletions
@@ -1,29 +1,3 @@
-from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from tensorrt_llm.llmapi.llm import _TorchLLM as LLM
 
-from transformers import PreTrainedTokenizerBase
-
-from ..llmapi.llm import LLM as BaseLLM
-from ..llmapi.llm import TokenizerBase
-
-
-class LLM(BaseLLM):
-
-    def __init__(self,
-                 model: str,
-                 tokenizer: Optional[Union[str, Path, TokenizerBase,
-                                           PreTrainedTokenizerBase]] = None,
-                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
-                 skip_tokenizer_init: bool = False,
-                 trust_remote_code: bool = False,
-                 tensor_parallel_size: int = 1,
-                 dtype: str = "auto",
-                 revision: Optional[str] = None,
-                 tokenizer_revision: Optional[str] = None,
-                 **kwargs: Any):
-
-        kwargs_dict = dict(kwargs)
-        kwargs_dict['backend'] = 'pytorch'
-        super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
-                         trust_remote_code, tensor_parallel_size, dtype,
-                         revision, tokenizer_revision, **kwargs_dict)
+__all__ = ['LLM']
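
The module is reduced to a re-export, so the identity check below holds by construction (the commented-out call uses an illustrative model path):

from tensorrt_llm._torch.llm import LLM
from tensorrt_llm.llmapi.llm import _TorchLLM

assert LLM is _TorchLLM  # this module is now a thin alias

# _TorchLLM injects backend='pytorch' itself, so callers no longer pass it:
# llm = LLM(model="/path/to/hf/checkpoint")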

tensorrt_llm/llmapi/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,7 @@
 from ..executor import CompletionOutput, RequestError
 from ..sampling_params import GuidedDecodingParams, SamplingParams
 from .build_cache import BuildCacheConfig
-from .llm import LLM, RequestOutput
+from .llm import LLM, RequestOutput, _TorchLLM, _TrtLLM
 # yapf: disable
 from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
                        CapacitySchedulerPolicy, ContextChunkingPolicy,
@@ -50,4 +50,6 @@
     'LlmArgs',
     'TorchLlmArgs',
     'TrtLlmArgs',
+    '_TrtLLM',
+    '_TorchLLM',
 ]
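
With the export list updated, both backend-specific classes become importable from the package root; the leading underscore signals they are not yet stable public API:

from tensorrt_llm.llmapi import LLM, _TorchLLM, _TrtLLM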

tensorrt_llm/llmapi/llm.py

Lines changed: 138 additions & 37 deletions
@@ -31,8 +31,9 @@
                           create_input_processor_with_hash, prompt_inputs)
 from ..logger import logger
 from ..sampling_params import SamplingParams
-from .llm_args import (LLMARGS_EXPLICIT_DOCSTRING, PybindMirror, TorchLlmArgs,
-                       TrtLlmArgs, _AutoDeployLlmArgs)
+from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
+                       TorchLlmArgs, TrtLlmArgs, _AutoDeployLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -83,23 +84,26 @@ def _repr_fields(self):
     ]
 
 
-LLM_DOCSTRING = LLMARGS_EXPLICIT_DOCSTRING + """
-    kwargs (Any): Advanced arguments passed to `LlmArgs`.
+TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
 
 Attributes:
     tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
     workspace (pathlib.Path): The directory to store intermediate files.
     llm_id (str): The unique ID of the LLM instance.
 """
 
+TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
 
-@append_docstring(LLM_DOCSTRING)
-class LLM:
-    """LLM class is the main class for running a LLM model.
-
-    Parameters:
+Attributes:
+    tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
     """
 
+
+class BaseLLM:
+    """
+    The base class for all LLM classes.
+    """
+
     def __init__(self,
                  model: Union[str, Path],
                  tokenizer: Optional[Union[str, Path, TokenizerBase,
@@ -186,6 +190,8 @@ def __init__(self,
         if self._on_trt_backend:
             self._workspace = tempfile.TemporaryDirectory(
                 suffix="-llm-workspace", dir=self.args.workspace)
+        else:
+            self._workspace = None
 
         self._hf_model_dir: Optional[Path] = None
 
@@ -202,10 +208,6 @@ def __init__(self,
         exception_handler.register(self, 'shutdown')
         atexit.register(LLM._shutdown_wrapper, weakref.ref(self))
 
-    @property
-    def workspace(self) -> Path:
-        return Path(self._workspace.name) if self._on_trt_backend else None
-
     @property
     def llm_id(self) -> str:
         if self._llm_id is None:
@@ -584,7 +586,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
     def _build_model(self):
         model_loader = CachedModelLoader(self.args,
                                          mpi_session=self.mpi_session,
-                                         workspace=self.workspace,
+                                         workspace=self._workspace,
                                          llm_build_stats=weakref.proxy(
                                              self.llm_build_stats))
         self._engine_dir, self._hf_model_dir = model_loader()
@@ -766,6 +768,66 @@ def tokenizer(self) -> Optional[TokenizerBase]:
     def tokenizer(self, tokenizer: TokenizerBase):
         self._tokenizer = tokenizer
 
+    def shutdown(self) -> None:
+        if hasattr(self, "_executor") and self._executor is not None:
+            self._executor.shutdown()
+            self._executor = None
+
+        if hasattr(self, 'mpi_session') and self.mpi_session is not None:
+            self.mpi_session.shutdown()
+            self.mpi_session = None
+
+    @staticmethod
+    def _shutdown_wrapper(self_ref):
+        # Retrieve the instance if it still exists
+        instance = self_ref()
+        if instance is not None:
+            instance.shutdown()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        del exc_value, traceback
+        self.shutdown()
+        return False  # propagate exceptions
+
+    def __getstate__(self):
+        raise RuntimeError("LLM object can not be pickled.")
+
+    def __del__(self):
+        self.shutdown()
+
+
+@append_docstring(TRT_LLM_DOCSTRING)
+class _TrtLLM(BaseLLM):
+    """LLM class is the main class for running a LLM model using TensorRT-LLM backend.
+
+    Parameters:
+    """
+
+    def __init__(self,
+                 model: Union[str, Path],
+                 tokenizer: Optional[Union[str, Path, TokenizerBase,
+                                           PreTrainedTokenizerBase]] = None,
+                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
+                 skip_tokenizer_init: bool = False,
+                 trust_remote_code: bool = False,
+                 tensor_parallel_size: int = 1,
+                 dtype: str = "auto",
+                 revision: Optional[str] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 **kwargs: Any) -> None:
+        # TODO: deprecate backend in LLM kwargs
+
+        super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
+                         trust_remote_code, tensor_parallel_size, dtype,
+                         revision, tokenizer_revision, **kwargs)
+
+    @property
+    def workspace(self) -> Path:
+        return Path(self._workspace.name) if self._on_trt_backend else None
+
     def save(self, engine_dir: str) -> None:
         """Save the built engine to the given path.
 
@@ -791,32 +853,71 @@ def save(self, engine_dir: str) -> None:
                 f"Copying {file} to {target_engine_dir / file.name}\n")
             shutil.copy(file, target_engine_dir / file.name)
 
-    def shutdown(self) -> None:
-        if hasattr(self, "_executor") and self._executor is not None:
-            self._executor.shutdown()
-            self._executor = None
 
-        if hasattr(self, 'mpi_session') and self.mpi_session is not None:
-            self.mpi_session.shutdown()
-            self.mpi_session = None
+@append_docstring(TORCH_LLM_DOCSTRING)
+class _TorchLLM(BaseLLM):
+    """LLM class is the main class for running a LLM model using PyTorch backend.
 
-    @staticmethod
-    def _shutdown_wrapper(self_ref):
-        # Retrieve the instance if it still exists
-        instance = self_ref()
-        if instance is not None:
-            instance.shutdown()
+    Parameters:
+    """
 
-    def __enter__(self):
-        return self
+    def __init__(self,
+                 model: Union[str, Path],
+                 tokenizer: Optional[Union[str, Path, TokenizerBase,
+                                           PreTrainedTokenizerBase]] = None,
+                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
+                 skip_tokenizer_init: bool = False,
+                 trust_remote_code: bool = False,
+                 tensor_parallel_size: int = 1,
+                 dtype: str = "auto",
+                 revision: Optional[str] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 **kwargs: Any) -> None:
 
-    def __exit__(self, exc_type, exc_value, traceback) -> bool:
-        del exc_value, traceback
-        self.shutdown()
-        return False  # propagate exceptions
+        # TODO: deprecate backend in LLM kwargs
+        kwargs.pop("backend", None)
 
-    def __getstate__(self):
-        raise RuntimeError("LLM object can not be pickled.")
+        super().__init__(model,
+                         tokenizer,
+                         tokenizer_mode,
+                         skip_tokenizer_init,
+                         trust_remote_code,
+                         tensor_parallel_size,
+                         dtype,
+                         revision,
+                         tokenizer_revision,
+                         backend='pytorch',
+                         **kwargs)
 
-    def __del__(self):
-        self.shutdown()
+
+class LLM(_TrtLLM):
+
+    def __init__(self,
+                 model: Union[str, Path],
+                 tokenizer: Optional[Union[str, Path, TokenizerBase,
+                                           PreTrainedTokenizerBase]] = None,
+                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
+                 skip_tokenizer_init: bool = False,
+                 trust_remote_code: bool = False,
+                 tensor_parallel_size: int = 1,
+                 dtype: str = "auto",
+                 revision: Optional[str] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 **kwargs: Any) -> None:
+        super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
+                         trust_remote_code, tensor_parallel_size, dtype,
+                         revision, tokenizer_revision, **kwargs)
+
+
+_LLM_REPR = "TrtLLM"
+
+# sphinx will ignore the LLM's docstring if it is not explicitly set
+LLM.__doc__ = \
+    f"""LLM class is the main class for running a LLM model.
+
+This class is an alias of {_LLM_REPR}. You can switch between the TensorRT backend
+and the PyTorch backend by setting the TLLM_USE_TRT_ENGINE environment to 1 or 0.
+The default backend is the TensorRT backend.
+
+Parameters:
+""" + TRT_LLM_DOCSTRING

tensorrt_llm/llmapi/llm_args.py

Lines changed: 7 additions & 3 deletions
@@ -1593,9 +1593,6 @@ def validate_enable_build_cache(self):
 
 LlmArgs = TrtLlmArgs
 
-LLMARGS_EXPLICIT_DOCSTRING = generate_api_docs_as_docstring(LlmArgs,
-                                                            indent=' ' * 4)
-
 
 class LoadFormat(Enum):
     AUTO = 0
@@ -2068,3 +2065,10 @@ def get_model_format(model_dir: str) -> _ModelFormatKind:
         )
     else:
         return model_format
+
+
+TRT_LLMARGS_EXPLICIT_DOCSTRING = generate_api_docs_as_docstring(TrtLlmArgs,
+                                                                indent=' ' * 4)
+TORCH_LLMARGS_EXPLICIT_DOCSTRING = generate_api_docs_as_docstring(TorchLlmArgs,
+                                                                  indent=' ' *
+                                                                  4)
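
These constants are generated at import time and concatenated into TRT_LLM_DOCSTRING / TORCH_LLM_DOCSTRING in llm.py. A minimal sketch of the pattern, assuming append_docstring simply concatenates text onto the class docstring (its real definition lives elsewhere in tensorrt_llm and may differ):

def append_docstring(docstring: str):
    """Assumed behavior: append generated parameter docs to a class docstring."""
    def decorator(cls):
        cls.__doc__ = (cls.__doc__ or "") + docstring
        return cls
    return decorator

# Stand-in for a generated TRT_LLMARGS_EXPLICIT_DOCSTRING fragment:
GENERATED_ARGS_DOC = "    model (str): The model checkpoint or engine path.\n"

@append_docstring(GENERATED_ARGS_DOC)
class _TrtLLM:
    """LLM class for the TensorRT backend.

    Parameters:
    """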
