
Commit 57f09a4

[Hardware][Intel] OpenVINO vLLM backend (#5379)
1 parent 5932634 commit 57f09a4

File tree

22 files changed: +1393, -23 lines


.buildkite/run-openvino-test.sh

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# This script builds the OpenVINO docker image and runs offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t openvino-test -f Dockerfile.openvino .

# Set up cleanup
remove_docker_container() { docker rm -f openvino-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py

Dockerfile.openvino

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
# The vLLM Dockerfile is used to construct a vLLM image that can be directly used
# to run the OpenAI compatible server.

FROM ubuntu:22.04 AS dev

RUN apt-get update -y && \
    apt-get install -y python3-pip git
WORKDIR /workspace

# copy requirements
COPY requirements-build.txt /workspace/vllm/
COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/

COPY vllm/ /workspace/vllm/vllm
COPY setup.py /workspace/vllm/

# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/

COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks

CMD ["/bin/bash"]

benchmarks/benchmark_latency.py

Lines changed: 4 additions & 3 deletions
@@ -207,9 +207,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda",
-        choices=["cuda", "cpu", "tpu", "xpu"],
-        help='device type for vLLM execution, supporting CUDA and CPU.')
+        default="auto",
+        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
+        help='device type for vLLM execution, supporting CUDA, OpenVINO and '
+        'CPU.')
     parser.add_argument('--block-size',
                         type=int,
                         default=16,

benchmarks/benchmark_throughput.py

Lines changed: 4 additions & 3 deletions
@@ -349,9 +349,10 @@ def main(args: argparse.Namespace):
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda",
-        choices=["cuda", "cpu", "tpu", "xpu"],
-        help='device type for vLLM execution, supporting CUDA and CPU.')
+        default="auto",
+        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
+        help='device type for vLLM execution, supporting CUDA, OpenVINO and '
+        'CPU.')
     parser.add_argument(
         "--enable-prefix-caching",
         action='store_true',
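
Both benchmark scripts now default to auto-detection and accept "openvino" as an explicit --device choice. A condensed, hypothetical stand-alone sketch of the same argparse pattern (not the benchmarks' full argument set) behaves like this:

    import argparse

    parser = argparse.ArgumentParser(description="device selection sketch")
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
        help="device type for vLLM execution")

    # e.g. `python benchmark_throughput.py --device openvino ...`
    args = parser.parse_args(["--device", "openvino"])
    print(args.device)  # -> openvino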
docs/source/getting_started/openvino-installation.rst

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
.. _installation_openvino:

Installation with OpenVINO
==========================

vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support. The OpenVINO vLLM backend supports the following advanced vLLM features:

- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)

**Table of contents**:

- :ref:`Requirements <openvino_backend_requirements>`
- :ref:`Quick start using Dockerfile <openvino_backend_quick_start_dockerfile>`
- :ref:`Build from source <install_openvino_backend_from_source>`
- :ref:`Performance tips <openvino_backend_performance_tips>`
- :ref:`Limitations <openvino_backend_limitations>`

.. _openvino_backend_requirements:

Requirements
------------

* OS: Linux
* Instruction set architecture (ISA) requirement: at least AVX2.

.. _openvino_backend_quick_start_dockerfile:

Quick start using Dockerfile
----------------------------

.. code-block:: console

    $ docker build -f Dockerfile.openvino -t vllm-openvino-env .
    $ docker run -it --rm vllm-openvino-env

.. _install_openvino_backend_from_source:

Install from source
-------------------

- First, install Python. For example, on Ubuntu 22.04, you can run:

  .. code-block:: console

      $ sudo apt-get update -y
      $ sudo apt-get install python3

- Second, install the prerequisites for the vLLM OpenVINO backend installation:

  .. code-block:: console

      $ pip install --upgrade pip
      $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu

- Finally, install vLLM with the OpenVINO backend:

  .. code-block:: console

      $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v .

.. _openvino_backend_performance_tips:

Performance tips
----------------

The vLLM OpenVINO backend uses the following environment variables to control its behavior:

- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user's memory management pattern.

- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used, depending on the platform.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weight compression during the model loading stage. By default, compression is turned off.

To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).

The best-known OpenVINO configuration is:

.. code-block:: console

    $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
        python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
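The same knobs can also be applied from Python for offline inference. The following is a minimal sketch, assuming the standard ``vllm.LLM`` entry point and that ``enable_chunked_prefill`` / ``max_num_batched_tokens`` are accepted as engine arguments; the model and values simply mirror the configuration above:

.. code-block:: python

    import os

    # The OpenVINO backend reads these variables when the engine starts,
    # so set them before constructing the LLM.
    os.environ["VLLM_OPENVINO_KVCACHE_SPACE"] = "100"
    os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"
    os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
              enable_chunked_prefill=True,
              max_num_batched_tokens=256)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)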
.. _openvino_backend_limitations:

Limitations
-----------

- LoRA serving is not supported.

- Only LLM models are currently supported. LLaVA and encoder-decoder models are not currently enabled in the vLLM OpenVINO integration.

- Tensor and pipeline parallelism are not currently enabled in the vLLM integration.

- Speculative sampling is not tested within the vLLM integration.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ Documentation

    getting_started/installation
    getting_started/amd-installation
+   getting_started/openvino-installation
    getting_started/cpu-installation
    getting_started/neuron-installation
    getting_started/tpu-installation

requirements-openvino.txt

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Common dependencies
-r requirements-common.txt

# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0.dev
optimum-intel[openvino] >= 1.17.2

triton >= 2.2.0  # FIXME(woosuk): This is a hack to avoid import error.

setup.py

Lines changed: 10 additions & 1 deletion
@@ -233,6 +233,10 @@ def _is_cpu() -> bool:
     return VLLM_TARGET_DEVICE == "cpu"


+def _is_openvino() -> bool:
+    return VLLM_TARGET_DEVICE == "openvino"
+
+
 def _is_xpu() -> bool:
     return VLLM_TARGET_DEVICE == "xpu"

@@ -337,6 +341,8 @@ def get_vllm_version() -> str:
         if neuron_version != MAIN_CUDA_VERSION:
             neuron_version_str = neuron_version.replace(".", "")[:3]
             version += f"+neuron{neuron_version_str}"
+    elif _is_openvino():
+        version += "+openvino"
     elif _is_tpu():
         version += "+tpu"
     elif _is_cpu():

@@ -388,6 +394,8 @@ def _read_requirements(filename: str) -> List[str]:
         requirements = _read_requirements("requirements-rocm.txt")
     elif _is_neuron():
         requirements = _read_requirements("requirements-neuron.txt")
+    elif _is_openvino():
+        requirements = _read_requirements("requirements-openvino.txt")
     elif _is_tpu():
         requirements = _read_requirements("requirements-tpu.txt")
     elif _is_cpu():

@@ -396,7 +404,8 @@ def _read_requirements(filename: str) -> List[str]:
         requirements = _read_requirements("requirements-xpu.txt")
     else:
         raise ValueError(
-            "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
+            "Unsupported platform, please use CUDA, ROCm, Neuron, "
+            "OpenVINO, or CPU.")
     return requirements
tests/kernels/test_attention_selector.py

Lines changed: 7 additions & 2 deletions
@@ -9,8 +9,8 @@


 @pytest.mark.parametrize(
-    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
+@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
     """Test that the attention selector can be set via environment variable.
     Note that we do not test FlashAttn because it is the default backend.

@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                     torch.float16, 16)
         assert backend.name == "ROCM_FLASH"
+    elif device == "openvino":
+        with patch("vllm.attention.selector.is_openvino", return_value=True):
+            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                        torch.float16, 16)
+        assert backend.name == "OPENVINO"
     else:
         backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                     torch.float16, 16)

vllm/attention/backends/openvino.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import List, Tuple

import openvino as ov
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)


class OpenVINOAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "openvino"

    @staticmethod
    def get_impl_cls():
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: ov.Tensor,
        dst_kv_cache: ov.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # OpenVINO currently supports only CPU, which does not require
        # swapping of KV cache blocks
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        for src, dst in src_to_dists:
            for key_cache, value_cache in kv_caches:
                key_cache.data[dst, :] = key_cache.data[src, :]
                value_cache.data[dst, :] = value_cache.data[src, :]


@dataclass
class OpenVINOAttentionMetadata:
    """Metadata for OpenVINOAttentionBackend.

    Basic terms used below:
    - batch_size_in_sequences - total number of sequences to execute
    - prompt_lens - per-sequence number of scheduled tokens
    - batch_size_in_tokens = sum(prompt_lens)
    - max_context_len = max(context_lens)
    - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)
    - num_blocks - total number of blocks in block_indices
    """

    # Describes past KV cache size for each sequence within a batch
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Describes start indices of input / speculative tokens from
    # current sequences within a batch sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # indices along the 0th dimension in key_cache and value_cache inputs
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # for the i-th element, it is an index in block_indices of the
    # first block belonging to the i-th sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Describes max context length
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor
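
To make the field definitions above concrete, here is a minimal sketch that builds the KV cache shape and the metadata tensors for a hypothetical two-sequence batch with block size 16; the concrete numbers are illustrative only, and the import path assumes the module lives at vllm.attention.backends.openvino:

    import torch

    from vllm.attention.backends.openvino import OpenVINOAttentionBackend

    BLOCK_SIZE = 16

    # KV cache layout: (2, num_blocks, num_kv_heads, block_size, head_size)
    kv_cache_shape = OpenVINOAttentionBackend.get_kv_cache_shape(
        num_blocks=512, block_size=BLOCK_SIZE, num_kv_heads=8, head_size=64)
    assert kv_cache_shape == (2, 512, 8, 16, 64)

    # Hypothetical batch: sequence 0 has 20 cached tokens and 3 new tokens,
    # sequence 1 has no cached tokens and a 5-token prompt.
    past_lens = torch.tensor([20, 0], dtype=torch.int32)
    # Prefix sums of scheduled tokens per sequence: [0, 3, 3 + 5].
    subsequence_begins = torch.tensor([0, 3, 8], dtype=torch.int32)
    # Sequence 0 occupies KV cache blocks 4 and 7; sequence 1 occupies block 2.
    block_indices = torch.tensor([4, 7, 2], dtype=torch.int32)
    # Sequence 0's blocks start at offset 0 in block_indices, sequence 1's at 2.
    block_indices_begins = torch.tensor([0, 2, 3], dtype=torch.int32)
    # max(20 + 3, 0 + 5) tokens of context.
    max_context_len = torch.tensor(23, dtype=torch.int32)

    metadata = OpenVINOAttentionBackend.make_openvino_metadata(
        past_lens, subsequence_begins, block_indices, block_indices_begins,
        max_context_len)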
