@mritterfigma (Contributor) commented Mar 17, 2025

Add support for Llama 3.2 Vision with CUDA graph capture by removing the block on CUDA graph capture for mllama.

Note: Follows a similar approach to the one taken by SGLang, making mllama aware of whether it is in graph capture mode or not (sgl-project/sglang@94cde10)

FIX "Enabling CUDA graph" in #8826 (comment)

Tested:

python3 -m vllm.entrypoints.cli.main serve unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit \
   --quantization="bitsandbytes" \
   --load-format="bitsandbytes" \
   --dtype=bfloat16 \
   --trust_remote_code \
   --gpu-memory-utilization=0.98 \
   --max-model-len=9600 \
   --max-num-seqs 4

The server started and responded to requests successfully (without enforce-eager). Also verified that it still works with --enforce-eager. Example request:

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "What is in this image?"
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
            }
          }
        ]
      }
    ],
    "max_tokens": 300
  }'


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@heheda12345 (Collaborator) left a comment

Great job! Some early feedback. I will check the code more carefully after the tests pass.

  1. Can you remove the enforce_eager in tests/models/encoder_decoder/vision_language/test_mllama.py and check whether the tests pass?
  2. Can you use model_config.enforce_eager to decide whether we should always run the cross-attention layers, instead of adding a capture_mode argument?

Signed-off-by: Matt Ritter <[email protected]>
@mritterfigma (Contributor, Author) commented Mar 17, 2025

> Great job! Some early feedback. I will check the code more carefully after the tests pass.
>
>   1. Can you remove the enforce_eager in tests/models/encoder_decoder/vision_language/test_mllama.py and check whether the tests pass?
>   2. Can you use model_config.enforce_eager to decide whether we should always run the cross-attention layers, instead of adding a capture_mode argument?

Thanks for the quick review!

For (1), I removed enforce_eager and verified that the tests pass (pytest tests/models/encoder_decoder/vision_language/test_mllama.py).

For (2), I tried that, but the model crashed during text-only inference. I think the crash happens because cross attention does not work when there are no image inputs (the NoneType in the error message below is most likely due to the missing images). Avoiding that line of code only during capture_mode, rather than for the whole serving session, sidesteps the issue.

ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/vllm/model_executor/models/mllama.py", line 1001, in forward
ERROR 03-17 10:42:06 [engine.py:158]     hidden_states = self.cross_attn(
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-17 10:42:06 [engine.py:158]     return self._call_impl(*args, **kwargs)
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-17 10:42:06 [engine.py:158]     return forward_call(*args, **kwargs)
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/vllm/model_executor/models/mllama.py", line 870, in forward
ERROR 03-17 10:42:06 [engine.py:158]     output = self.attn(
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-17 10:42:06 [engine.py:158]     return self._call_impl(*args, **kwargs)
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-17 10:42:06 [engine.py:158]     return forward_call(*args, **kwargs)
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/vllm/attention/layer.py", line 214, in forward
ERROR 03-17 10:42:06 [engine.py:158]     torch.ops.vllm.unified_attention_with_output(
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/venv/lib/python3.10/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 03-17 10:42:06 [engine.py:158]     return self._op(*args, **(kwargs or {}))
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/vllm/attention/layer.py", line 363, in unified_attention_with_output
ERROR 03-17 10:42:06 [engine.py:158]     self.impl.forward(self,
ERROR 03-17 10:42:06 [engine.py:158]   File "/home/ubuntu/fork/vllm/vllm/attention/backends/flash_attn.py", line 753, in forward
ERROR 03-17 10:42:06 [engine.py:158]     key = key[:num_prefill_kv_tokens]
ERROR 03-17 10:42:06 [engine.py:158] TypeError: 'NoneType' object is not subscriptable
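To make the failure mode concrete, here is a tiny standalone snippet (not vLLM code) that reproduces the same error: with no image inputs there are no encoder key/value tensors for cross-attention, so slicing the key fails.

```python
# Minimal reproduction of the failure mode described above (not vLLM code).
num_prefill_kv_tokens = 0
key = None  # what the cross-attention backend effectively gets for a text-only batch
try:
    key = key[:num_prefill_kv_tokens]  # mirrors the slice in the traceback
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable
```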

@heheda12345 (Collaborator)

(1) That's great to know. Thanks!
(2) I just noticed that we have max_encoder_seq_len in both FlashAttentionMetadata and XFormersMetadata. Using that field to replace the max(attn_metadata.encoder_seq_lens) == 0 check should no longer break CUDA graph capture (a sketch of this check follows below).
(3) Given the special case of text-only input, @sroy745, can the CUDA graph captured for the cross-attention kernels with both encoder tokens and decoder tokens be used for a batch with no encoder tokens?
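To illustrate point (2), here is a rough sketch of the suggested check, using a simplified stand-in type rather than the real FlashAttentionMetadata/XFormersMetadata classes:

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical, simplified stand-in for vLLM's attention metadata; the real
# classes carry many more fields.
@dataclass
class AttnMetadataSketch:
    encoder_seq_lens: List[int] = field(default_factory=list)
    max_encoder_seq_len: int = 0


def skip_cross_attention(attn_metadata: AttnMetadataSketch) -> bool:
    # Old-style check: max() over a Python list of per-sequence lengths,
    # which is what got in the way of CUDA graph capture:
    #   return max(attn_metadata.encoder_seq_lens) == 0
    # Suggested check: rely on the precomputed scalar instead.
    return attn_metadata.max_encoder_seq_len == 0


if __name__ == "__main__":
    text_only = AttnMetadataSketch(encoder_seq_lens=[0, 0], max_encoder_seq_len=0)
    with_images = AttnMetadataSketch(encoder_seq_lens=[0, 128], max_encoder_seq_len=128)
    print(skip_cross_attention(text_only))    # True  -> skip cross-attention
    print(skip_cross_attention(with_images))  # False -> run cross-attention
```

In this sketch, text-only batches have max_encoder_seq_len == 0, so the cross-attention layers are skipped; with image inputs it is positive and they run as before.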

@mritterfigma force-pushed the llama-vision-cuda-graph branch from 50ac218 to 8876572 on March 19, 2025 at 21:23
@mritterfigma (Contributor, Author)

> (1) That's great to know. Thanks! (2) I just noticed that we have max_encoder_seq_len in both FlashAttentionMetadata and XFormersMetadata. Using that field to replace the max(attn_metadata.encoder_seq_lens) == 0 check should no longer break CUDA graph capture. (3) Given the special case of text-only input, @sroy745, can the CUDA graph captured for the cross-attention kernels with both encoder tokens and decoder tokens be used for a batch with no encoder tokens?

Switching to max_encoder_seq_len works! That removes the need for the introduced capture_mode argument, which is nice. Thanks for the suggestion.

I'm not sure about the answer to (3). I know the model works with text-only, image-only, and text-image inputs. But I'm not sure how to tell whether the CUDA graph is actually being used. I suppose we could tell with some load test, but I'm not sure whether we have any existing tooling for that in vLLM.

@sroy745 (Collaborator) commented Mar 19, 2025

> (1) That's great to know. Thanks! (2) I just noticed that we have max_encoder_seq_len in both FlashAttentionMetadata and XFormersMetadata. Using that field to replace the max(attn_metadata.encoder_seq_lens) == 0 check should no longer break CUDA graph capture. (3) Given the special case of text-only input, @sroy745, can the CUDA graph captured for the cross-attention kernels with both encoder tokens and decoder tokens be used for a batch with no encoder tokens?

> Switching to max_encoder_seq_len works! That removes the need for the introduced capture_mode argument, which is nice. Thanks for the suggestion.

> I'm not sure about the answer to (3). I know the model works with text-only, image-only, and text-image inputs. But I'm not sure how to tell whether the CUDA graph is actually being used. I suppose we could tell with some load test, but I'm not sure whether we have any existing tooling for that in vLLM.

If the tests pass with enforce_eager=False, it should be working fine with CUDA graphs, assuming we have test cases with text-only, image-only, and text-image combinations.
You should be able to run the load tests with the server running in enforce_eager=True/False mode. I think the load tests (which use the sonnet/ShareGPT datasets) should be able to test text-only, but I'm not sure about image-only and text-image combinations.
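For anyone who wants to try the load-test idea without existing tooling, here is a rough, hand-rolled sketch against the OpenAI-compatible endpoint used earlier in this thread. It is not an official vLLM benchmark script; the URL, model name, and image URL simply reuse the values from the commands above. Run it once against a server started with --enforce-eager and once without, and compare the reported latencies.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://localhost:8000/v1/chat/completions"
MODEL = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit"
IMAGE_URL = ("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/"
             "Gfp-wisconsin-madison-the-nature-boardwalk.jpg/"
             "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")


def one_request(with_image: bool) -> float:
    """Send a single chat completion and return its wall-clock latency."""
    content = [{"type": "text", "text": "Describe a lake in two sentences."}]
    if with_image:
        content.append({"type": "image_url", "image_url": {"url": IMAGE_URL}})
    start = time.perf_counter()
    resp = requests.post(URL, json={
        "model": MODEL,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 64,
    }, timeout=300)
    resp.raise_for_status()
    return time.perf_counter() - start


def run(n: int, with_image: bool, concurrency: int = 4) -> None:
    # Fire n requests with a small amount of concurrency and report the mean.
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        latencies = list(pool.map(one_request, [with_image] * n))
    label = "text+image" if with_image else "text-only"
    print(f"{label}: mean latency {sum(latencies) / n:.2f}s over {n} requests")


if __name__ == "__main__":
    run(16, with_image=False)
    run(16, with_image=True)
```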

@heheda12345 (Collaborator) left a comment

@sroy745 Thanks for your advice.
I've tested this PR on my local machine. All tests in test_mllama.py pass except the following two:

FAILED test_mllama.py::test_models_interleaved_images[_Backend.XFORMERS-5-128-bfloat16-meta-llama/Llama-3.2-11B-Vision-Instruct] - AttributeError: 'list' object has no attribute 'shape'
FAILED test_mllama.py::test_models_interleaved_images[_Backend.FLASH_ATTN-5-128-bfloat16-meta-llama/Llama-3.2-11B-Vision-Instruct] - AttributeError: 'list' object has no attribute 'shape'

I think this problem will be fixed by #14883, so I approve this PR. @mritterfigma Thanks for your contribution.

@vllm-bot merged commit a8652f4 into vllm-project:main on Mar 20, 2025
13 checks passed