Skip to content

Commit ad4a9dc

Browse files
authored
[cuda] manually import the correct pynvml module (#12679)
fixes problems like #12635 and #12636 and #12565 --------- Signed-off-by: youkaichao <[email protected]>
1 parent b998645 commit ad4a9dc

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

vllm/platforms/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def cuda_platform_plugin() -> Optional[str]:
3333
is_cuda = False
3434

3535
try:
36-
import pynvml
36+
from vllm.utils import import_pynvml
37+
pynvml = import_pynvml()
3738
pynvml.nvmlInit()
3839
try:
3940
if pynvml.nvmlDeviceGetCount() > 0:

vllm/platforms/cuda.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
99
Union)
1010

11-
import pynvml
1211
import torch
1312
from typing_extensions import ParamSpec
1413

1514
# import custom ops, trigger op registration
1615
import vllm._C # noqa
1716
import vllm.envs as envs
1817
from vllm.logger import init_logger
18+
from vllm.utils import import_pynvml
1919

2020
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
2121

@@ -29,13 +29,7 @@
2929
_P = ParamSpec("_P")
3030
_R = TypeVar("_R")
3131

32-
if pynvml.__file__.endswith("__init__.py"):
33-
logger.warning(
34-
"You are using a deprecated `pynvml` package. Please install"
35-
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
36-
" When both of them are installed, `pynvml` will take precedence"
37-
" and cause errors. See https://pypi.org/project/pynvml "
38-
"for more information.")
32+
pynvml = import_pynvml()
3933

4034
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
4135
# see https://github.com/huggingface/diffusers/issues/9704 for details

vllm/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
22082208
else:
22092209
func = partial(method, obj) # type: ignore
22102210
return func(*args, **kwargs)
2211+
2212+
2213+
def import_pynvml():
2214+
"""
2215+
Historical comments:
2216+
2217+
libnvml.so is the library behind nvidia-smi, and
2218+
pynvml is a Python wrapper around it. We use it to get GPU
2219+
status without initializing CUDA context in the current process.
2220+
Historically, there are two packages that provide pynvml:
2221+
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
2222+
wrapper. It is a dependency of vLLM, and is installed when users
2223+
install vLLM. It provides a Python module named `pynvml`.
2224+
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
2225+
Prior to version 12.0, it also provides a Python module `pynvml`,
2226+
and therefore conflicts with the official one. What's worse,
2227+
the module is a Python package, and has higher priority than
2228+
the official one which is a standalone Python file.
2229+
This causes errors when both of them are installed.
2230+
Starting from version 12.0, it migrates to a new module
2231+
named `pynvml_utils` to avoid the conflict.
2232+
2233+
TL;DR: if users have pynvml<12.0 installed, it will cause problems.
2234+
Otherwise, `import pynvml` will import the correct module.
2235+
We take the safest approach here, to manually import the correct
2236+
`pynvml.py` module from the `nvidia-ml-py` package.
2237+
"""
2238+
if TYPE_CHECKING:
2239+
import pynvml
2240+
return pynvml
2241+
if "pynvml" in sys.modules:
2242+
import pynvml
2243+
if pynvml.__file__.endswith("__init__.py"):
2244+
# this is pynvml < 12.0
2245+
raise RuntimeError(
2246+
"You are using a deprecated `pynvml` package. "
2247+
"Please uninstall `pynvml` or upgrade to at least"
2248+
" version 12.0. See https://pypi.org/project/pynvml "
2249+
"for more information.")
2250+
return sys.modules["pynvml"]
2251+
import importlib.util
2252+
import os
2253+
import site
2254+
for site_dir in site.getsitepackages():
2255+
pynvml_path = os.path.join(site_dir, "pynvml.py")
2256+
if os.path.exists(pynvml_path):
2257+
spec = importlib.util.spec_from_file_location(
2258+
"pynvml", pynvml_path)
2259+
pynvml = importlib.util.module_from_spec(spec)
2260+
sys.modules["pynvml"] = pynvml
2261+
spec.loader.exec_module(pynvml)
2262+
return pynvml

0 commit comments

Comments
 (0)