@@ -1008,87 +1008,51 @@ def args_bounds_check(
10081008 return args [i ] if len (args ) > i and args [i ] is not None else replacement
10091009
10101010
1011- def install_wget (platform : str ) -> None :
1012- if shutil .which ("wget" ):
1013- _LOGGER .debug ("wget is already installed" )
1014- return
1015- if platform .startswith ("linux" ):
1016- try :
1017- # if its root
1018- if os .geteuid () == 0 :
1019- subprocess .run (["apt-get" , "update" ], check = True )
1020- subprocess .run (["apt-get" , "install" , "-y" , "wget" ], check = True )
1021- else :
1022- _LOGGER .debug ("Please run with sudo permissions" )
1023- subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1024- subprocess .run (["sudo" , "apt-get" , "install" , "-y" , "wget" ], check = True )
1025- except subprocess .CalledProcessError as e :
1026- _LOGGER .debug ("Error installing wget:" , e )
1027-
1028-
1029- def install_mpi (platform : str ) -> None :
1030- if platform .startswith ("linux" ):
1031- try :
1032- # if its root
1033- if os .geteuid () == 0 :
1034- subprocess .run (["apt-get" , "update" ], check = True )
1035- subprocess .run (["apt-get" , "install" , "-y" , "libmpich-dev" ], check = True )
1036- subprocess .run (
1037- ["apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1038- )
1039- else :
1040- _LOGGER .debug ("Please run with sudo permissions" )
1041- subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1042- subprocess .run (
1043- ["sudo" , "apt-get" , "install" , "-y" , "libmpich-dev" ], check = True
1044- )
1045- subprocess .run (
1046- ["sudo" , "apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1047- )
1048- except subprocess .CalledProcessError as e :
1049- _LOGGER .debug ("Error installing mpi libs:" , e )
1050-
1051-
10521011def download_plugin_lib_path (py_version : str , platform : str ) -> str :
10531012 plugin_lib_path = None
1054- if py_version not in ("cp310" , "cp312" ):
1055- _LOGGER .warning (
1056- "No available wheel for python versions other than py3.10 and py3.12"
1057- )
1058- install_wget (platform )
1013+
1014+ # Downloading TRT-LLM lib
1015+ # TODO: check how to fix the 0.18.0 hardcode below
10591016 base_url = "https://pypi.nvidia.com/tensorrt-llm/"
1060- file_name = f"tensorrt_llm-0.17 .0.post1-{ py_version } -{ py_version } -{ platform } .whl"
1017+ file_name = f"tensorrt_llm-0.18 .0.post1-{ py_version } -{ py_version } -{ platform } .whl"
10611018 download_url = base_url + file_name
10621019 cmd = ["wget" , download_url ]
1063- try :
1064- if not (os .path .exists (file_name )):
1065- _LOGGER .info (f"Running command: { ' ' .join (cmd )} " )
1066- subprocess .run (cmd )
1067- _LOGGER .info ("Download complete of wheel" )
1068- if os .path .exists (file_name ):
1069- _LOGGER .info ("filename now present" )
1070- if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
1071- plugin_lib_path = (
1072- "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1073- )
1074- else :
1075- import zipfile
1020+ if not (os .path .exists (file_name )):
1021+ try :
1022+ subprocess .run (cmd , check = True )
1023+ _LOGGER .debug ("Download succeeded and TRT-LLM wheel is now present" )
1024+ except subprocess .CalledProcessError as e :
1025+ _LOGGER .error (
1026+ "Download failed (file not found or connection issue). Error code:" ,
1027+ e .returncode ,
1028+ )
1029+ except FileNotFoundError :
1030+ _LOGGER .error ("wget is required but not found. Please install wget." )
10761031
1077- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1078- zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
1079- plugin_lib_path = (
1080- "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1081- )
1082- except subprocess .CalledProcessError as e :
1083- _LOGGER .debug (f"Error occurred while trying to download: { e } " )
1084- except Exception as e :
1085- _LOGGER .debug (f"An unexpected error occurred: { e } " )
1032+ # Proceeding with the unzip of the wheel file
1033+ # This will exist if the filename was already downloaded
1034+ if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
1035+ plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1036+ else :
1037+ try :
1038+ import zipfile
1039+ except :
1040+ raise ImportError (
1041+ "zipfile module is required but not found. Please install zipfile"
1042+ )
1043+ with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1044+ zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
1045+ plugin_lib_path = (
1046+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1047+ )
10861048 return plugin_lib_path
10871049
10881050
10891051def load_tensorrt_llm () -> bool :
10901052 """
10911053 Attempts to load the TensorRT-LLM plugin and initialize it.
1054+ Either the env variable TRTLLM_PLUGINS_PATH specifies the path
1055+ If the above is not, the user can specify USE_TRTLLM_PLUGINS as either of 1, true, yes, on to download the TRT-LLM distribution and load it
10921056
10931057 Returns:
10941058 bool: True if the plugin was successfully loaded and initialized, False otherwise.
@@ -1098,8 +1062,9 @@ def load_tensorrt_llm() -> bool:
10981062 _LOGGER .warning (
10991063 "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
11001064 )
1101- for key , value in os .environ .items ():
1102- print (f"{ key } : { value } " )
1065+ # for key, value in os.environ.items():
1066+ # print(f"{key}: {value}")
1067+ # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
11031068 use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
11041069 "1" ,
11051070 "true" ,
@@ -1112,14 +1077,14 @@ def load_tensorrt_llm() -> bool:
11121077 )
11131078 return False
11141079 else :
1115- py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
1080+ # this is used as the default py version
1081+ py_version = f"cp312"
11161082 platform = Platform .current_platform ()
11171083
11181084 platform = str (platform ).lower ()
11191085 plugin_lib_path = download_plugin_lib_path (py_version , platform )
11201086 try :
1121- # Load the shared
1122- install_mpi (platform )
1087+ # Load the shared TRT-LLM file
11231088 handle = ctypes .CDLL (plugin_lib_path )
11241089 _LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
11251090 except OSError as e_os_error :
0 commit comments