88import tempfile
99import urllib .request
1010import warnings
11- from contextlib import contextmanager
1211from dataclasses import fields , replace
1312from enum import Enum
1413from pathlib import Path
1514from typing import (
1615 Any ,
1716 Callable ,
1817 Dict ,
19- Iterator ,
2018 List ,
2119 Optional ,
2220 Sequence ,
@@ -864,40 +862,52 @@ def is_platform_supported_for_trtllm(platform: str) -> bool:
864862 return True
865863
866864
867- @contextmanager
868- def download_plugin_lib_path (platform : str ) -> Iterator [str ]:
869- """
870- Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
871- then yields the path to the extracted shared library (.so or .dll).
def _cache_root() -> Path:
    """Return the per-user cache directory path used for downloaded artifacts.

    The path is ``<system temp dir>/torch_tensorrt_<username>``. This function
    only builds the path; it does not create the directory.

    Returns:
        Path: The cache root for the current user.
    """
    try:
        username = getpass.getuser()
    except (KeyError, ImportError, OSError):
        # getpass.getuser() raises when no login name can be determined, e.g.
        # a process running under a UID that has no password-database entry
        # and with none of LOGNAME/USER/LNAME/USERNAME set. Fall back to a
        # UID-derived name so the cache path can still be computed.
        import os

        getuid = getattr(os, "getuid", None)
        username = f"uid{getuid()}" if getuid is not None else "unknown"
    # NOTE(review): the location is predictable and lives under the shared
    # temp directory; confirm that is acceptable for whatever is loaded from it.
    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
872868
873- The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
874- Extraction happens in a temporary directory that is cleaned up after use.
875869
876- Args :
877- platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
def _extracted_dir_trtllm(platform: str) -> Path:
    """Return the directory path for the extracted TensorRT-LLM wheel contents.

    The path is ``<cache root>/trtllm/<tensorrt_llm version>_<platform>``.
    This function only builds the path; it does not create the directory.

    Args:
        platform (str): Platform identifier (e.g., 'linux_x86_64').

    Returns:
        Path: The version- and platform-specific extraction directory.
    """
    trtllm_root = _cache_root() / "trtllm"
    versioned_name = f"{__tensorrt_llm_version__}_{platform}"
    return trtllm_root / versioned_name
878872
879- Yields:
880- str: The full path to the extracted TensorRT-LLM shared library file.
881873
882- Raises:
883- ImportError: If the 'zipfile' module is not available.
884- RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
874+ def download_and_get_plugin_lib_path (platform : str ) -> Optional [str ]:
885875 """
886- plugin_lib_path = None
887- username = getpass .getuser ()
888- torchtrt_cache_dir = Path (tempfile .gettempdir ()) / f"torch_tensorrt_{ username } "
889- torchtrt_cache_dir .mkdir (parents = True , exist_ok = True )
890- file_name = f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
891- torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
892- downloaded_file_path = torchtrt_cache_trtllm_whl
893-
894- if not torchtrt_cache_trtllm_whl .exists ():
895- # Downloading TRT-LLM lib
876+ Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
877+
878+ Args:
879+ platform (str): Platform identifier (e.g., 'linux_x86_64')
880+
881+ Returns:
882+ Optional[str]: Path to shared library or None if operation fails.
883+ """
884+ wheel_filename = (
885+ f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -"
886+ f"{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
887+ )
888+ wheel_path = _cache_root () / wheel_filename
889+ extract_dir = _extracted_dir_trtllm (platform )
890+ # else will never be met though
891+ lib_filename = (
892+ "libnvinfer_plugin_tensorrt_llm.so"
893+ if "linux" in platform
894+ else "libnvinfer_plugin_tensorrt_llm.dll"
895+ )
896+ # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
897+ plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename
898+
899+ if plugin_lib_path .exists ():
900+ return str (plugin_lib_path )
901+
902+ wheel_path .parent .mkdir (parents = True , exist_ok = True )
903+ extract_dir .mkdir (parents = True , exist_ok = True )
904+
905+ if not wheel_path .exists ():
896906 base_url = "https://pypi.nvidia.com/tensorrt-llm/"
897- download_url = base_url + file_name
907+ download_url = base_url + wheel_filename
898908 try :
899909 logger .debug (f"Downloading { download_url } ..." )
900- urllib .request .urlretrieve (download_url , downloaded_file_path )
910+ urllib .request .urlretrieve (download_url , wheel_path )
901911 logger .debug ("Download succeeded and TRT-LLM wheel is now present" )
902912 except urllib .error .HTTPError as e :
903913 logger .error (
@@ -910,41 +920,45 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
910920 except OSError as e :
911921 logger .error (f"Local file write error: { e } " )
912922
913- # Proceeding with the unzip of the wheel file in tmpdir
914- if "linux" in platform :
915- lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
916- else :
917- # This condition is never met though
918- lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
923+ try :
924+ import zipfile
925+ except ImportError as e :
926+ raise ImportError (
927+ "zipfile module is required but not found. Please install zipfile"
928+ )
929+ try :
930+ with zipfile .ZipFile (wheel_path ) as zip_ref :
931+ zip_ref .extractall (extract_dir )
932+ logger .debug (f"Extracted wheel to { extract_dir } " )
933+ except FileNotFoundError as e :
934+ # This should capture the errors in the download failure above
935+ logger .error (f"Wheel file not found at { wheel_path } : { e } " )
936+ raise RuntimeError (
937+ f"Failed to find downloaded wheel file at { wheel_path } "
938+ ) from e
939+ except zipfile .BadZipFile as e :
940+ logger .error (f"Invalid or corrupted wheel file: { e } " )
941+ raise RuntimeError (
942+ "Downloaded wheel file is corrupted or not a valid zip archive"
943+ ) from e
944+ except Exception as e :
945+ logger .error (f"Unexpected error while extracting wheel: { e } " )
946+ raise RuntimeError (
947+ "Unexpected error during extraction of TensorRT-LLM wheel"
948+ ) from e
919949
920- with tempfile .TemporaryDirectory () as tmpdir :
921- try :
922- import zipfile
923- except ImportError :
924- raise ImportError (
925- "zipfile module is required but not found. Please install zipfile"
926- )
927- try :
928- with zipfile .ZipFile (downloaded_file_path , "r" ) as zip_ref :
929- zip_ref .extractall (tmpdir ) # Extract to a folder named 'tensorrt_llm'
930- except FileNotFoundError as e :
931- # This should capture the errors in the download failure above
932- logger .error (f"Wheel file not found at { downloaded_file_path } : { e } " )
933- raise RuntimeError (
934- f"Failed to find downloaded wheel file at { downloaded_file_path } "
935- ) from e
936- except zipfile .BadZipFile as e :
937- logger .error (f"Invalid or corrupted wheel file: { e } " )
938- raise RuntimeError (
939- "Downloaded wheel file is corrupted or not a valid zip archive"
940- ) from e
941- except Exception as e :
942- logger .error (f"Unexpected error while extracting wheel: { e } " )
943- raise RuntimeError (
944- "Unexpected error during extraction of TensorRT-LLM wheel"
945- ) from e
946- plugin_lib_path = os .path .join (tmpdir , "tensorrt_llm/libs" , lib_filename )
947- yield plugin_lib_path
950+ try :
951+ wheel_path .unlink (missing_ok = True )
952+ logger .debug (f"Deleted wheel file: { wheel_path } " )
953+ except Exception as e :
954+ logger .warning (f"Could not delete wheel file { wheel_path } : { e } " )
955+ if not plugin_lib_path .exists ():
956+ logger .error (
957+ f"Plugin library not found at expected location: { plugin_lib_path } "
958+ )
959+ return None
960+
961+ return str (plugin_lib_path )
948962
949963
950964def load_and_initialize_trtllm_plugin (plugin_lib_path : str ) -> bool :
@@ -1034,6 +1048,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
10341048 )
10351049 return False
10361050
1037- with download_plugin_lib_path (platform ) as plugin_lib_path :
1038- return load_and_initialize_trtllm_plugin (plugin_lib_path )
1051+ plugin_lib_path = download_and_get_plugin_lib_path (platform )
1052+ return load_and_initialize_trtllm_plugin (plugin_lib_path ) # type: ignore[arg-type]
10391053 return False
0 commit comments