
Commit 0f9d0d4

[BugFix] Avoid unnecessary Ray import warnings (vllm-project#6079)

njhill authored and rshaw@neuralmagic.com committed
1 parent 99833da

File tree

3 files changed: +28 −9 lines

    vllm/config.py
    vllm/engine/async_llm_engine.py
    vllm/executor/ray_utils.py

vllm/config.py (7 additions, 2 deletions)

@@ -714,11 +714,13 @@ def __init__(
 
             from vllm.executor import ray_utils
             backend = "mp"
-            ray_found = ray_utils.ray is not None
+            ray_found = ray_utils.ray_is_available()
             if cuda_device_count_stateless() < self.world_size:
                 if not ray_found:
                     raise ValueError("Unable to load Ray which is "
-                                     "required for multi-node inference")
+                                     "required for multi-node inference, "
+                                     "please install Ray with `pip install "
+                                     "ray`.") from ray_utils.ray_import_err
                 backend = "ray"
             elif ray_found:
                 if self.placement_group:

@@ -750,6 +752,9 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "Unrecognized distributed executor backend. Supported values "
                 "are 'ray' or 'mp'.")
+        if self.distributed_executor_backend == "ray":
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
         if not self.disable_custom_all_reduce and self.world_size > 1:
             if is_hip():
                 self.disable_custom_all_reduce = True
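
The `raise ... from ray_utils.ray_import_err` form chains the stored ImportError onto the new ValueError, so the traceback shows both why Ray was needed and why its import originally failed. A minimal standalone sketch of that chaining, with `nonexistent_module` as a stand-in for any uninstalled optional dependency:

# Sketch of the `raise ... from <err>` chaining used above; assumes
# `nonexistent_module` is not installed, so the import actually fails.
try:
    import nonexistent_module
except ImportError as e:
    stored_import_err = e  # kept for later, like ray_utils.ray_import_err

try:
    raise ValueError("Unable to load Ray which is required for "
                     "multi-node inference, please install Ray with "
                     "`pip install ray`.") from stored_import_err
except ValueError as err:
    # The original ImportError rides along as __cause__, so an unhandled
    # traceback would print both errors, joined by "The above exception
    # was the direct cause of the following exception:".
    assert err.__cause__ is stored_import_err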

vllm/engine/async_llm_engine.py (5 additions, 0 deletions)

@@ -380,6 +380,11 @@ def from_engine_args(
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
+
+        if engine_args.engine_use_ray:
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
+
         distributed_executor_backend = (
             engine_config.parallel_config.distributed_executor_backend)
 
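
This guard makes a missing Ray fail when the engine is constructed rather than deep inside startup. A hedged sketch of the same fail-fast shape; the `EngineArgs` below is a minimal stand-in, not vLLM's real dataclass:

# Fail-fast sketch: validate the optional dependency when the engine is
# built, not when Ray is first exercised. `EngineArgs` is simplified.
from dataclasses import dataclass


@dataclass
class EngineArgs:
    engine_use_ray: bool = False


def from_engine_args(engine_args: EngineArgs) -> None:
    if engine_args.engine_use_ray:
        # Lazy import: users who never request Ray pay no import cost and,
        # after this commit, see no spurious warning either.
        from vllm.executor import ray_utils
        ray_utils.assert_ray_available()
    # ... continue constructing the engine ...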

vllm/executor/ray_utils.py (16 additions, 7 deletions)

@@ -42,14 +42,26 @@ def execute_model_compiled_dag_remote(self, ignored):
             output = pickle.dumps(output)
             return output
 
+    ray_import_err = None
+
 except ImportError as e:
-    logger.warning(
-        "Failed to import Ray with %r. For multi-node inference, "
-        "please install Ray with `pip install ray`.", e)
     ray = None  # type: ignore
+    ray_import_err = e
     RayWorkerWrapper = None  # type: ignore
 
 
+def ray_is_available() -> bool:
+    """Returns True if Ray is available."""
+    return ray is not None
+
+
+def assert_ray_available():
+    """Raise an exception if Ray is not available."""
+    if ray is None:
+        raise ValueError("Failed to import Ray, please install Ray with "
+                         "`pip install ray`.") from ray_import_err
+
+
 def initialize_ray_cluster(
     parallel_config: ParallelConfig,
     ray_address: Optional[str] = None,

@@ -65,10 +77,7 @@ def initialize_ray_cluster(
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
     """
-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use multi-node "
-            "serving.")
+    assert_ray_available()
 
     # Connect to a ray cluster.
     if is_hip() or is_xpu():
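
Taken together, ray_utils.py now follows a common recipe for optional dependencies: import silently, remember the failure, and surface it only when the feature is actually requested. A generic sketch of that recipe under assumed names (`optional_dep.py` and `some_extra` are hypothetical, not part of vLLM):

# optional_dep.py -- generic sketch of the silent optional-import pattern.
# `some_extra` is a hypothetical package used purely for illustration.
try:
    import some_extra

    _import_err = None
except ImportError as e:
    some_extra = None  # type: ignore
    _import_err = e


def is_available() -> bool:
    """Return True if the optional dependency imported successfully."""
    return some_extra is not None


def assert_available() -> None:
    """Raise, chaining the original ImportError, if the dep is missing."""
    if some_extra is None:
        raise ValueError("some_extra is required for this feature, please "
                         "install it with `pip install some_extra`."
                         ) from _import_err

Importing such a module never logs anything, which is exactly the warning this commit removes; callers that truly need the dependency opt in to the hard failure via assert_available().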
