Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions docs/source/speeding_up_training.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@ To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm)
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:

```bash
pip install vllm
```

or

```bash
pip install "trl[vllm]"
pip install "trl[vllm]"
```

<hfoptions id="vllm examples">
Expand Down
6 changes: 6 additions & 0 deletions docs/source/vllm_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥

<Tip warning={true}>

TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. Please ensure you have one of these versions installed to avoid compatibility issues.

</Tip>

## 🚀 How can I use vLLM with TRL to speed up training?

💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/evals/judge_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

# /// script
# dependencies = [
# "trl",
# "vllm",
# "trl[vllm]",
# ]
# ///

Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

# /// script
# dependencies = [
# "trl",
# "trl[vllm]",
# "peft",
# "math-verify",
# "latex2sympy2_extended",
# "trackio",
# "vllm",
# "kernels",
# ]
# ///
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ test =
vllm =
# vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added:
# see https://github.com/vllm-project/vllm/pull/13164
vllm>=0.10.0; python_version < "3.13"
vllm>=0.10.0,<=0.10.2; python_version < "3.13"
fastapi; python_version < "3.13"
pydantic; python_version < "3.13"
requests; python_version < "3.13"
Expand Down
2 changes: 1 addition & 1 deletion trl/extras/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
if not is_vllm_available():
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
raise ImportError("vLLM is not installed. Please install it with `pip install trl[vllm]`.")

self.session = requests.Session()

Expand Down
12 changes: 11 additions & 1 deletion trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import importlib
import os
import warnings
from itertools import chain
from types import ModuleType
from typing import Any
Expand All @@ -35,7 +36,7 @@
_requests_available = _is_package_available("requests")
_unsloth_available = _is_package_available("unsloth")
_uvicorn_available = _is_package_available("uvicorn")
_vllm_available = _is_package_available("vllm")
_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
_vllm_ascend_available = _is_package_available("vllm_ascend")
_weave_available = _is_package_available("weave")

Expand Down Expand Up @@ -81,6 +82,15 @@ def is_uvicorn_available() -> bool:


def is_vllm_available() -> bool:
    """Return whether vLLM is importable, warning if the installed version is unsupported.

    TRL is only validated against vLLM 0.10.0–0.10.2; any other installed
    version triggers a `UserWarning` but availability is still reported.
    """
    if _vllm_available:
        installed = version.parse(_vllm_version)
        # Supported range is inclusive on both ends: [0.10.0, 0.10.2].
        in_range = version.parse("0.10.0") <= installed <= version.parse("0.10.2")
        if not in_range:
            warnings.warn(
                "TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. You have version "
                f"{_vllm_version} installed. We recommend to install one of these versions to avoid compatibility issues.",
                UserWarning,
            )
    return _vllm_available


Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def __init__(
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
                "`pip install trl[vllm]` to use it."
)

if self.vllm_mode == "server":
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
default=False,
metadata={
"help": "Whether to use vLLM for generating completions. Requires vLLM to be installed "
"(`pip install vllm`)."
"(`pip install trl[vllm]`)."
},
)
vllm_model_impl: str = field(
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def __init__(
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
"`pip install trl[vllm]` to use it."
)

if self.vllm_mode == "server":
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def decode(example, tokenizer):
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
"`pip install trl[vllm]` to use it."
)

if self.vllm_mode == "server":
Expand Down
Loading