|
3 | 3 | import functools |
4 | 4 | import logging |
5 | 5 | import os |
6 | | -import shutil |
7 | | -import subprocess |
8 | | -import sys |
9 | 6 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload |
10 | 7 |
|
11 | 8 | import numpy as np |
|
16 | 13 | from torch.fx.node import Argument, Target |
17 | 14 | from torch.fx.passes.shape_prop import TensorMetadata |
18 | 15 | from torch_tensorrt import _enums |
19 | | -from torch_tensorrt._enums import Platform |
20 | 16 | from torch_tensorrt.dynamo._settings import CompilationSettings |
21 | 17 | from torch_tensorrt.dynamo._SourceIR import SourceIR |
22 | 18 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
@@ -1006,123 +1002,3 @@ def args_bounds_check( |
1006 | 1002 | args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None |
1007 | 1003 | ) -> Any: |
1008 | 1004 | return args[i] if len(args) > i and args[i] is not None else replacement |
1009 | | - |
1010 | | - |
1011 | | -def download_plugin_lib_path(py_version: str, platform: str) -> str: |
1012 | | - plugin_lib_path = None |
1013 | | - |
1014 | | - # Downloading TRT-LLM lib |
1015 | | - # TODO: check how to fix the 0.18.0 hardcode below |
1016 | | - base_url = "https://pypi.nvidia.com/tensorrt-llm/" |
1017 | | - file_name = f"tensorrt_llm-0.18.0-{py_version}-{py_version}-{platform}.whl" |
1018 | | - download_url = base_url + file_name |
1019 | | - cmd = ["wget", download_url] |
1020 | | - if not (os.path.exists(file_name)): |
1021 | | - try: |
1022 | | - subprocess.run(cmd, check=True) |
1023 | | - _LOGGER.debug("Download succeeded and TRT-LLM wheel is now present") |
1024 | | - except subprocess.CalledProcessError as e: |
1025 | | - _LOGGER.error( |
1026 | | - "Download failed (file not found or connection issue). Error code:", |
1027 | | - e.returncode, |
1028 | | - ) |
1029 | | - except FileNotFoundError: |
1030 | | - _LOGGER.error("wget is required but not found. Please install wget.") |
1031 | | - |
1032 | | - # Proceeding with the unzip of the wheel file |
1033 | | - # This will exist if the filename was already downloaded |
1034 | | - if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"): |
1035 | | - plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so" |
1036 | | - else: |
1037 | | - try: |
1038 | | - import zipfile |
1039 | | - except: |
1040 | | - raise ImportError( |
1041 | | - "zipfile module is required but not found. Please install zipfile" |
1042 | | - ) |
1043 | | - with zipfile.ZipFile(file_name, "r") as zip_ref: |
1044 | | - zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm' |
1045 | | - plugin_lib_path = ( |
1046 | | - "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so" |
1047 | | - ) |
1048 | | - return plugin_lib_path |
1049 | | - |
1050 | | - |
1051 | | -def load_tensorrt_llm() -> bool: |
1052 | | - """ |
1053 | | - Attempts to load the TensorRT-LLM plugin and initialize it. |
1054 | | - Either the env variable TRTLLM_PLUGINS_PATH can specify the path |
1055 | | - Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it |
1056 | | -
|
1057 | | - Returns: |
1058 | | - bool: True if the plugin was successfully loaded and initialized, False otherwise. |
1059 | | - """ |
1060 | | - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") |
1061 | | - if not plugin_lib_path: |
1062 | | - # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user |
1063 | | - use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( |
1064 | | - "1", |
1065 | | - "true", |
1066 | | - "yes", |
1067 | | - "on", |
1068 | | - ) |
1069 | | - if not use_trtllm_plugin: |
1070 | | - _LOGGER.warning( |
1071 | | - "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" |
1072 | | - ) |
1073 | | - return False |
1074 | | - else: |
1075 | | - # this is used as the default py version |
1076 | | - py_version = f"cp312" |
1077 | | - platform = Platform.current_platform() |
1078 | | - |
1079 | | - platform = str(platform).lower() |
1080 | | - plugin_lib_path = download_plugin_lib_path(py_version, platform) |
1081 | | - |
1082 | | - try: |
1083 | | - # Load the shared TRT-LLM file |
1084 | | - handle = ctypes.CDLL(plugin_lib_path) |
1085 | | - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") |
1086 | | - except OSError as e_os_error: |
1087 | | - if "libmpi" in str(e_os_error): |
1088 | | - _LOGGER.warning( |
1089 | | - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " |
1090 | | - f"The dependency libmpi.so is missing. " |
1091 | | - f"Please install the packages libmpich-dev and libopenmpi-dev.", |
1092 | | - exc_info=e_os_error, |
1093 | | - ) |
1094 | | - else: |
1095 | | - _LOGGER.warning( |
1096 | | - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" |
1097 | | - f"Ensure the path is correct and the library is compatible", |
1098 | | - exc_info=e_os_error, |
1099 | | - ) |
1100 | | - return False |
1101 | | - |
1102 | | - try: |
1103 | | - # Configure plugin initialization arguments |
1104 | | - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] |
1105 | | - handle.initTrtLlmPlugins.restype = ctypes.c_bool |
1106 | | - except AttributeError as e_plugin_unavailable: |
1107 | | - _LOGGER.warning( |
1108 | | - "Unable to initialize the TensorRT-LLM plugin library", |
1109 | | - exc_info=e_plugin_unavailable, |
1110 | | - ) |
1111 | | - return False |
1112 | | - |
1113 | | - try: |
1114 | | - # Initialize the plugin |
1115 | | - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" |
1116 | | - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): |
1117 | | - _LOGGER.info("TensorRT-LLM plugin successfully initialized") |
1118 | | - return True |
1119 | | - else: |
1120 | | - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") |
1121 | | - return False |
1122 | | - except Exception as e_initialization_error: |
1123 | | - _LOGGER.warning( |
1124 | | - "Exception occurred during TensorRT-LLM plugin library initialization", |
1125 | | - exc_info=e_initialization_error, |
1126 | | - ) |
1127 | | - return False |
1128 | | - return False |
0 commit comments