import ctypes
import gc
+ import getpass
import logging
import os
+ import tempfile
import urllib.request
import warnings
+ from contextlib import contextmanager
from dataclasses import fields, replace
from enum import Enum
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from pathlib import Path
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+     Sequence,
+     Tuple,
+     Union,
+ )

import numpy as np
import sympy
RTOL = 5e-3
ATOL = 5e-3
CPU_DEVICE = "cpu"
+ _WHL_CPYTHON_VERSION = "cp310"


class Frameworks(Enum):
@@ -240,6 +255,19 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
    """
    if parent_logger:
        parent_logger.setLevel(level)
+         # Attach a stream handler only if none is configured yet, so debug
+         # messages become visible without duplicating output on repeated calls.
+         if not parent_logger.hasHandlers():
+             ch = logging.StreamHandler()
+             ch.setLevel(logging.DEBUG)  # Allow debug messages on handler
+             formatter = logging.Formatter(
+                 "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+             )
+             ch.setFormatter(formatter)
+             parent_logger.addHandler(ch)

    if ENABLED_FEATURES.torch_tensorrt_runtime:
        if level == logging.DEBUG:
@@ -826,17 +854,41 @@ def is_tegra_platform() -> bool:
    return False

- def download_plugin_lib_path(py_version: str, platform: str) -> str:
-     plugin_lib_path = None
+ @contextmanager
+ def download_plugin_lib_path(platform: str) -> Iterator[str]:
+     """
+     Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
+     then yields the path to the extracted shared library (.so or .dll).

-     # Downloading TRT-LLM lib
-     base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-     file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
-     download_url = base_url + file_name
-     if not (os.path.exists(file_name)):
+     The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
+     Extraction happens in a temporary directory that is cleaned up after use.
+ 
+     Args:
+         platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
+ 
+     Yields:
+         str: The full path to the extracted TensorRT-LLM shared library file.
+ 
+     Raises:
+         ImportError: If the 'zipfile' module is not available.
+         RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
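+ 
+     Example (a minimal usage sketch, assuming a 'linux_x86_64' platform string):
+         >>> with download_plugin_lib_path("linux_x86_64") as lib_path:
+         ...     handle = ctypes.CDLL(lib_path)  # load the extracted plugin library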
+     """
+     plugin_lib_path = None
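+     # Cache the wheel in a per-user directory under the system temp dir so that
+     # repeated calls can reuse an already downloaded wheel.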
+     username = getpass.getuser()
+     torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
+     torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
+     file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
+     torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
+     downloaded_file_path = torchtrt_cache_trtllm_whl
+ 
+     if not torchtrt_cache_trtllm_whl.exists():
+         # Downloading TRT-LLM lib
+         base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+         download_url = base_url + file_name
+         logger.info("Downloading TRT-LLM wheel")
        try:
            logger.debug(f"Downloading {download_url}...")
-             urllib.request.urlretrieve(download_url, file_name)
+             urllib.request.urlretrieve(download_url, downloaded_file_path)
            logger.debug("Download succeeded and TRT-LLM wheel is now present")
        except urllib.error.HTTPError as e:
            logger.error(
@@ -849,60 +901,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str:
        except OSError as e:
            logger.error(f"Local file write error: {e}")

-     # Proceeding with the unzip of the wheel file
-     # This will exist if the filename was already downloaded
+     # Proceeding with the unzip of the wheel file in tmpdir
    if "linux" in platform:
        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
    else:
        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
-     plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
-     if os.path.exists(plugin_lib_path):
-         return plugin_lib_path
-     try:
-         import zipfile
-     except ImportError as e:
-         raise ImportError(
-             "zipfile module is required but not found. Please install zipfile"
-         )
-     with zipfile.ZipFile(file_name, "r") as zip_ref:
-         zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
-     plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
-     return plugin_lib_path
- 

- def load_tensorrt_llm() -> bool:
+     with tempfile.TemporaryDirectory() as tmpdir:
+         try:
+             import zipfile
+         except ImportError:
+             raise ImportError(
+                 "zipfile module is required but not found. Please install zipfile"
+             )
+         try:
+             with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
+                 zip_ref.extractall(tmpdir)  # The wheel unpacks into a 'tensorrt_llm' folder inside tmpdir
+         except FileNotFoundError as e:
+             # Raised when the wheel is missing, e.g. because the download above failed and the error was only logged
+             logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
+             raise RuntimeError(
+                 f"Failed to find downloaded wheel file at {downloaded_file_path}"
+             ) from e
+         except zipfile.BadZipFile as e:
+             logger.error(f"Invalid or corrupted wheel file: {e}")
+             raise RuntimeError(
+                 "Downloaded wheel file is corrupted or not a valid zip archive"
+             ) from e
+         except Exception as e:
+             logger.error(f"Unexpected error while extracting wheel: {e}")
+             raise RuntimeError(
+                 "Unexpected error during extraction of TensorRT-LLM wheel"
+             ) from e
+         plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
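+         # Yield while the TemporaryDirectory is still alive; the extracted files
+         # are removed automatically when the context manager exits.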
+         yield plugin_lib_path
+ 
+ 
+ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
    """
-     Attempts to load the TensorRT-LLM plugin and initialize it.
-     Either the env variable TRTLLM_PLUGINS_PATH can specify the path
-     Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
+     Loads and initializes the TensorRT-LLM plugin from the given shared library path.
+ 
+     Args:
+         plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

    Returns:
-         bool: True if the plugin was successfully loaded and initialized, False otherwise.
+         bool: True if successful, False otherwise.
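+ 
+     Example (a minimal sketch; the library path shown is only an illustrative assumption):
+         >>> load_and_initialize_trtllm_plugin("/tmp/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so")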
    """
-     plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-     if not plugin_lib_path:
-         # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
-         use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
-             "1",
-             "true",
-             "yes",
-             "on",
-         )
-         if not use_trtllm_plugin:
-             logger.warning(
-                 "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
-             )
-             return False
-         else:
-             # this is used as the default py version
-             py_version = "cp310"
-             platform = Platform.current_platform()
- 
-             platform = str(platform).lower()
-             plugin_lib_path = download_plugin_lib_path(py_version, platform)
- 
    try:
-         # Load the shared TRT-LLM file
        handle = ctypes.CDLL(plugin_lib_path)
        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
@@ -915,14 +960,13 @@ def load_tensorrt_llm() -> bool:
            )
        else:
            logger.warning(
-                 f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                 f"Ensure the path is correct and the library is compatible",
+                 f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                 f"Ensure the path is correct and the library is compatible.",
                exc_info=e_os_error,
            )
        return False

    try:
-         # Configure plugin initialization arguments
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
@@ -933,9 +977,7 @@ def load_tensorrt_llm() -> bool:
        return False

    try:
-         # Initialize the plugin
-         TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-         if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+         if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
            logger.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
@@ -948,3 +990,37 @@ def load_tensorrt_llm() -> bool:
        )
        return False
    return False
+ 
+ 
+ def load_tensorrt_llm() -> bool:
+     """
+     Attempts to load the TensorRT-LLM plugin and initialize it.
+     The TRTLLM_PLUGINS_PATH environment variable can specify the plugin path directly,
+     or USE_TRTLLM_PLUGINS can be set to one of (1, true, yes, on) to download the TRT-LLM distribution and load it.
+ 
+     Returns:
+         bool: True if the plugin was successfully loaded and initialized, False otherwise.
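+ 
+     Example (a minimal sketch; enabling the documented download path via USE_TRTLLM_PLUGINS):
+         >>> os.environ["USE_TRTLLM_PLUGINS"] = "1"
+         >>> load_tensorrt_llm()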
+     """
+     plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+     if plugin_lib_path:
+         return load_and_initialize_trtllm_plugin(plugin_lib_path)
+     else:
+         # This option can be used if TRTLLM_PLUGINS_PATH is not set by the user
+         use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+             "1",
+             "true",
+             "yes",
+             "on",
+         )
+         if not use_trtllm_plugin:
+             logger.warning(
+                 "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library. Please set one of the two to use TRT-LLM libraries in torchTRT"
+             )
+             return False
+         else:
+             platform = Platform.current_platform()
+             platform = str(platform).lower()
+ 
+             with download_plugin_lib_path(platform) as plugin_lib_path:
+                 return load_and_initialize_trtllm_plugin(plugin_lib_path)
+     return False