Skip to content
Merged
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ steps:
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 offline_inference_vision_language.py
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py

Expand Down
2 changes: 1 addition & 1 deletion docs/source/getting_started/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ If you have already taken care of the above issues, but the vLLM instance still

With more logging, hopefully you can find the root cause of the issue.

If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.

Here are some common issues that can cause hangs:

Expand Down
6 changes: 4 additions & 2 deletions docs/source/getting_started/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Offline Batched Inference

We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts.

Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process.
Import :class:`~vllm.LLM` and :class:`~vllm.SamplingParams` from vLLM.
The :class:`~vllm.LLM` class is the main class for running offline inference with vLLM engine.
The :class:`~vllm.SamplingParams` class specifies the parameters for the sampling process.

.. code-block:: python

Expand All @@ -42,7 +44,7 @@ Define the list of input prompts and the sampling parameters for generation. The
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.
Initialize vLLM's engine for offline inference with the :class:`~vllm.LLM` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.

.. code-block:: python

Expand Down
21 changes: 12 additions & 9 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ Multimodal Language Models

* - Architecture
- Models
- Supported Modalities
- Modalities
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- Image
- Image\ :sup:`E`
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
* - :code:`ChameleonForConditionalGeneration`
Expand All @@ -214,40 +214,43 @@ Multimodal Language Models
-
* - :code:`InternVLChatModel`
- InternVL2
- Image
- Image\ :sup:`E`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- Image
- Image\ :sup:`E`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- Image
- Image\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image
- Image\ :sup:`E`
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision, Phi-3.5-Vision
- Image
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- Image
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`UltravoxModel`
- Ultravox
- Audio
- Audio\ :sup:`E+`
- :code:`fixie-ai/ultravox-v0_3`
-

| :sup:`E` Pre-computed embeddings can be inputted for this modality.
| :sup:`+` Multiple items can be inputted per text prompt for this modality.

.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
Expand Down
123 changes: 88 additions & 35 deletions docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,23 @@ This document shows you how to run and serve these models using vLLM.
.. important::
We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.

Currently, the support for vision language models on vLLM has the following limitations:

* Only single image input is supported per text prompt.

We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.

Offline Batched Inference
-------------------------
Offline Inference
-----------------

Single-image input
^^^^^^^^^^^^^^^^^^

To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine.
The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models.

.. code-block:: python

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

.. important::
.. note::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.

the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:

Expand Down Expand Up @@ -86,61 +83,117 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI

A code example can be found in `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_.

Multi-image input
^^^^^^^^^^^^^^^^^

Online OpenAI Vision API Compatible Inference
----------------------------------------------
Multi-image input is only supported for a subset of VLMs, as shown :ref:`here <supported_vlms>`.

You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class.

.. note::
Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be
added in the future.
.. code-block:: python

Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with vLLM API server.
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True, # Required to load Phi-3.5-vision
max_model_len=4096, # Otherwise, it may not fit in smaller GPUs
limit_mm_per_prompt={"image": 2}, # The maximum number to accept
)

.. important::
Since OpenAI Vision API is based on `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template
is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the
HuggingFace Llava chat template that you can find in the example folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
Instead of passing in a single image, you can pass in a list of images.

.. code-block:: python

# Refer to the HuggingFace repo for the correct format to use
prompt = "<|user|>\n<image_1>\n<image_2>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"

# Load the images using PIL.Image
image1 = PIL.Image.open(...)
image2 = PIL.Image.open(...)

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": [image1, image2]
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)

A code example can be found in `examples/offline_inference_vision_language_multi_image.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language_multi_image.py>`_.

Online Inference
----------------

OpenAI Vision API
^^^^^^^^^^^^^^^^^

You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.

Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server.

.. code-block:: bash

vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
--trust-remote-code --limit-mm-per-prompt image=2

.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.
Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
a chat template is **required** to launch the API server.

Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.

To consume the server, you can use the OpenAI client like in the example below:

.. code-block:: python

from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)

# Single-image input inference
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

chat_response = client.chat.completions.create(
model="llava-hf/llava-1.5-7b-hf",
model="microsoft/Phi-3.5-vision-instruct",
messages=[{
"role": "user",
"content": [
# NOTE: The prompt formatting with the image token `<image>` is not needed
# since the prompt will be processed automatically by the API server.
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
},
},
{"type": "text", "text": "What’s in this image?"},
{"type": "image_url", "image_url": {"url": image_url}},
],
}],
)
print("Chat response:", chat_response)
print("Chat completion output:", chat_response.choices[0].message.content)

# Multi-image input inference
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"

chat_response = client.chat.completions.create(
model="microsoft/Phi-3.5-vision-instruct",
messages=[{
"role": "user",
"content": [
{"type": "text", "text": "What are the animals in these images?"},
{"type": "image_url", "image_url": {"url": image_url_duck}},
{"type": "image_url", "image_url": {"url": image_url_lion}},
],
}],
)
print("Chat completion output:", chat_response.choices[0].message.content)


A full code example can be found in `examples/openai_vision_api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_vision_api_client.py>`_.

Expand Down
95 changes: 95 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined
by the model.
"""
from argparse import Namespace
from typing import List

from vllm import LLM
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]


def _load_phi3v(image_urls: List[str]):
return LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)


def run_phi3v_generate(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)

placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def run_phi3v_chat(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)

outputs = llm.chat([{
"role":
"user",
"content": [
{
"type": "text",
"text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
},
} for image_url in image_urls),
],
}])

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main(args: Namespace):
method = args.method

if method == "generate":
run_phi3v_generate(QUESTION, IMAGE_URLS)
elif method == "chat":
run_phi3v_chat(QUESTION, IMAGE_URLS)
else:
raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input')
parser.add_argument("--method",
type=str,
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.")

args = parser.parse_args()
main(args)
Loading