@@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
22082208    else :
22092209        func  =  partial (method , obj )  # type: ignore 
22102210    return  func (* args , ** kwargs )
2211+ 
2212+ 
2213+ def  import_pynvml ():
2214+     """ 
2215+     Historical comments: 
2216+ 
2217+     libnvml.so is the library behind nvidia-smi, and 
2218+     pynvml is a Python wrapper around it. We use it to get GPU 
2219+     status without initializing CUDA context in the current process. 
2220+     Historically, there are two packages that provide pynvml: 
2221+     - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official 
2222+         wrapper. It is a dependency of vLLM, and is installed when users 
2223+         install vLLM. It provides a Python module named `pynvml`. 
2224+     - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. 
2225+         Prior to version 12.0, it also provides a Python module `pynvml`, 
2226+         and therefore conflicts with the official one. What's worse, 
2227+         the module is a Python package, and has higher priority than 
2228+         the official one which is a standalone Python file. 
2229+         This causes errors when both of them are installed. 
2230+         Starting from version 12.0, it migrates to a new module 
2231+         named `pynvml_utils` to avoid the conflict. 
2232+      
2233+     TL;DR: if users have pynvml<12.0 installed, it will cause problems. 
2234+     Otherwise, `import pynvml` will import the correct module. 
2235+     We take the safest approach here, to manually import the correct 
2236+     `pynvml.py` module from the `nvidia-ml-py` package. 
2237+     """ 
2238+     if  TYPE_CHECKING :
2239+         import  pynvml 
2240+         return  pynvml 
2241+     if  "pynvml"  in  sys .modules :
2242+         import  pynvml 
2243+         if  pynvml .__file__ .endswith ("__init__.py" ):
2244+             # this is pynvml < 12.0 
2245+             raise  RuntimeError (
2246+                 "You are using a deprecated `pynvml` package. " 
2247+                 "Please uninstall `pynvml` or upgrade to at least" 
2248+                 " version 12.0. See https://pypi.org/project/pynvml " 
2249+                 "for more information." )
2250+         return  sys .modules ["pynvml" ]
2251+     import  importlib .util 
2252+     import  os 
2253+     import  site 
2254+     for  site_dir  in  site .getsitepackages ():
2255+         pynvml_path  =  os .path .join (site_dir , "pynvml.py" )
2256+         if  os .path .exists (pynvml_path ):
2257+             spec  =  importlib .util .spec_from_file_location (
2258+                 "pynvml" , pynvml_path )
2259+             pynvml  =  importlib .util .module_from_spec (spec )
2260+             sys .modules ["pynvml" ] =  pynvml 
2261+             spec .loader .exec_module (pynvml )
2262+             return  pynvml 
0 commit comments