
Conversation


@sairampillai commented Sep 26, 2025

[Bugfix] Improve GPU validation logging in Ray fallback scenarios

Adds early GPU-count validation and clearer Ray placement-group error messages when tensor_parallel_size exceeds the available GPUs, to address poor logging and help users diagnose K8s deployment failures.

Related Issues

Fixes #25263

Purpose

Fixes poor logging when tensor_parallel_size exceeds available GPUs in Ray fallback scenarios.

When tensor_parallel_size is set higher than the available GPU count (e.g., tensor_parallel_size=4 with only 1 GPU), vLLM silently falls back to the Ray executor without an adequate warning. This leads to confusing errors in K8s deployments, where users see Ray placement group timeout errors without understanding the root cause.

Changes Made

  1. Early GPU validation in vllm/config/parallel.py: Added a warning when the tensor parallel size exceeds the available GPUs during backend selection (see the sketch after this list)
  2. Enhanced Ray placement error messages in vllm/executor/ray_utils.py: Improved the error messages raised by _wait_until_pg_ready() and initialize_ray_cluster() to provide context about GPU resource mismatches
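
For illustration, the early check in vllm/config/parallel.py follows roughly this shape. The helper name, the exact message, and the use of torch.cuda.device_count() are assumptions for this sketch, not the patch verbatim:

import logging

import torch

logger = logging.getLogger("vllm.config.parallel")


def _warn_if_tp_exceeds_local_gpus(tensor_parallel_size: int) -> None:
    """Illustrative early check: warn when the requested tensor parallel size
    cannot be satisfied by the GPUs visible to this process."""
    available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if tensor_parallel_size > available_gpus:
        # A warning (not a hard error) because a Ray cluster may still provide
        # GPUs on other nodes; the goal is to surface the likely root cause early.
        logger.warning(
            "Tensor parallel size (%d) exceeds available GPUs (%d). "
            "Ray placement may stall waiting for resources; consider reducing "
            "--tensor-parallel-size or adding GPUs to the cluster.",
            tensor_parallel_size,
            available_gpus,
        )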

Files Modified

  • vllm/config/parallel.py - Added GPU count validation with clear warnings
  • vllm/executor/ray_utils.py - Enhanced Ray placement group error handling

Test Plan

Scenario Testing

  1. Single GPU scenario with multi-GPU tensor parallel: Test with --tensor-parallel-size 4 on a system with only 1 available GPU
  2. K8s GPU resource mismatch: Verify error messages in constrained K8s environments where pod requests only 1 GPU but tensor parallel size > 1
  3. Normal operation: Ensure no impact when GPU resources match tensor parallel requirements

Test Commands

# Test 1: Check the warning when tensor_parallel_size > available GPUs
python -c "
import logging
logging.basicConfig(level=logging.WARNING)
from vllm.config.parallel import ParallelConfig
config = ParallelConfig(tensor_parallel_size=4)
print('Config test completed')
"

# Test 2: Ray integration test (requires a multi-GPU setup)
PYTHONPATH=. python examples/offline_inference.py \
  --model microsoft/DialoGPT-small \
  --prompt "Hello world" \
  --tensor-parallel-size 2  # Triggers the validation warning if only 1 GPU is available
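
As an automated counterpart to Test 1, a hypothetical regression test could assert that the warning is emitted. The test name, the warning substring, and the assumption that the warning propagates to the root logger (so caplog can capture it) are all illustrative:

# test_tp_gpu_warning.py -- hypothetical regression test, not part of this PR
import logging

import pytest

from vllm.config.parallel import ParallelConfig


def test_warns_when_tp_exceeds_gpus(caplog: pytest.LogCaptureFixture) -> None:
    # Assumes this runs on a machine with fewer than 4 visible GPUs and that
    # the warning is logged during config construction/validation.
    with caplog.at_level(logging.WARNING):
        ParallelConfig(tensor_parallel_size=4)
    assert "exceeds available GPUs" in caplog.text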

Functional Testing

  • Verify that warning messages appear at the correct configuration stage
  • Ensure normal operation remains unaffected when GPU resources are properly configured
  • Test the Ray cluster initialization warning when a GPU mismatch is detected

Test Result

Before Fix

  • No early warning when tensor_parallel_size exceeds available GPUs
  • Cryptic Ray placement group timeout errors:
    ValueError: Cannot provide a placement group of 'placement_group_specs=...' within 2550 seconds
    

After Fix

  • Early warning during configuration:
    WARNING: Tensor parallel size (4) exceeds available GPUs (1). This will likely cause issues. Consider reducing tensor_parallel_size to 1 or less...
    
  • Enhanced Ray placement error with actionable guidance (pattern sketched below):
    ValueError: Cannot provide a placement group requiring 4 GPUs (...) within 2550 seconds.
    Tensor parallel size may exceed available GPUs in your cluster. Check resources with `ray status` and `ray list nodes`.
    If running on K8s with limited GPUs, consider reducing --tensor-parallel-size to match available GPU resources.
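
The improved message comes from wrapping the placement-group wait and re-raising with context. Below is a minimal sketch of that pattern using Ray's public placement-group API; the wrapper name, parameters, and exact wording are assumptions, not the code from the PR:

import ray
from ray.exceptions import GetTimeoutError
from ray.util.placement_group import PlacementGroup


def _wait_for_placement_group(pg: PlacementGroup, required_gpus: int,
                              timeout_s: float) -> None:
    """Illustrative wrapper: turn a bare placement-group timeout into an
    actionable error that names the GPU requirement."""
    try:
        # pg.ready() returns an ObjectRef that resolves once Ray has reserved
        # every bundle; ray.get raises GetTimeoutError if that never happens.
        ray.get(pg.ready(), timeout=timeout_s)
    except GetTimeoutError as e:
        raise ValueError(
            f"Cannot provide a placement group requiring {required_gpus} GPUs "
            f"within {timeout_s} seconds. Tensor parallel size may exceed the "
            "GPUs available in your Ray cluster. Check resources with "
            "`ray status` and `ray list nodes`; on K8s, consider reducing "
            "--tensor-parallel-size to match the pod's GPU request.") from e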
    

Validation Results

  • Code quality checks passed: pre-commit hooks, format checks, lint checks
  • Backward compatibility preserved: No breaking changes to existing behavior
  • Enhanced user experience: Clear error messages guide users to resolution
  • K8s scenario targeted: Specific guidance for Kubernetes deployment issues

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@robertgshaw2-redhat
Collaborator

Instead of improving the log, can we just not allow vLLM to start? I don't quite understand why the fall-back-to-Ray behavior is something that is needed.

@cjackal
Contributor

cjackal commented Sep 27, 2025

I also think that this silent fallback behavior is not only confusing but also fairly dangerous, in the sense that a small typo in the model server maintainer's deployment configuration results in complete failure only after a long delay. And since the distribution logic and the user-facing deployment workflow are quite different, I think users are already well aware of which distribution backend they intend to use. While it wouldn't be a backward-compatible change, I'd +1 an explicit declaration of the distribution backend.

(I'm not claiming that this should be addressed in this PR; just agreeing with @robertgshaw2-redhat's comment above.)

@sairampillai
Author

I agree, @robertgshaw2-redhat @cjackal. Do you think we should close/merge this PR and then discuss the fallback scenario in a wider forum? Or should I go ahead and create a new PR for explicit declaration and early stopping?
