@@ -841,6 +841,29 @@ def is_tegra_platform() -> bool:
841841 return False
842842
843843
844+ def is_platform_supported_for_trtllm (platform : str ) -> bool :
845+ """
846+ Checks if the current platform supports TensorRT-LLM plugins for NCCL backend
847+ Returns:
848+ bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise.
849+ Note:
850+ TensorRT-LLM plugins for NCCL backend are not supported on:
851+ - Windows platforms
852+ - Jetson devices (aarch64 architecture)
853+ """
854+ if "windows" in platform :
855+ logger .info (
856+ "TensorRT-LLM plugins for NCCL backend are not supported on Windows"
857+ )
858+ return False
859+ if "aarch64" in platform :
860+ logger .info (
861+ "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)"
862+ )
863+ return False
864+ return True
865+
866+
844867@contextmanager
845868def download_plugin_lib_path (platform : str ) -> Iterator [str ]:
846869 """
@@ -891,6 +914,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
891914 if "linux" in platform :
892915 lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
893916 else :
917+ # This condition is never met though
894918 lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
895919
896920 with tempfile .TemporaryDirectory () as tmpdir :
@@ -923,7 +947,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
923947 yield plugin_lib_path
924948
925949
926- def load_and_initialize_trtllm_plugin (plugin_lib_path : str , platform : str ) -> bool :
950+ def load_and_initialize_trtllm_plugin (plugin_lib_path : str ) -> bool :
927951 """
928952 Loads and initializes the TensorRT-LLM plugin from the given shared library path.
929953
@@ -933,9 +957,6 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
933957 Returns:
934958 bool: True if successful, False otherwise.
935959 """
936- if "windows" in platform :
937- logger .info ("NCCL backend is not supported on Windows" )
938- return False
939960 try :
940961 handle = ctypes .CDLL (plugin_lib_path )
941962 logger .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
@@ -981,7 +1002,7 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
9811002 return False
9821003
9831004
984- def load_tensorrt_llm () -> bool :
1005+ def load_tensorrt_llm_for_nccl () -> bool :
9851006 """
9861007 Attempts to load the TensorRT-LLM plugin and initialize it.
9871008 Either the env variable TRTLLM_PLUGINS_PATH can specify the path
@@ -990,11 +1011,15 @@ def load_tensorrt_llm() -> bool:
9901011 Returns:
9911012 bool: True if the plugin was successfully loaded and initialized, False otherwise.
9921013 """
993- plugin_lib_path = os . environ . get ( "TRTLLM_PLUGINS_PATH" )
1014+ # Check platform compatibility first
9941015 platform = Platform .current_platform ()
9951016 platform = str (platform ).lower ()
1017+ if not is_platform_supported_for_trtllm (platform ):
1018+ return False
1019+ plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
1020+
9961021 if plugin_lib_path :
997- return load_and_initialize_trtllm_plugin (plugin_lib_path , platform )
1022+ return load_and_initialize_trtllm_plugin (plugin_lib_path )
9981023 else :
9991024 # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
10001025 use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
@@ -1010,5 +1035,5 @@ def load_tensorrt_llm() -> bool:
10101035 return False
10111036
10121037 with download_plugin_lib_path (platform ) as plugin_lib_path :
1013- return load_and_initialize_trtllm_plugin (plugin_lib_path , platform )
1038+ return load_and_initialize_trtllm_plugin (plugin_lib_path )
10141039 return False
0 commit comments