diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
index 6a3392aa6f9..57586295f8f 100644
--- a/docs/source/speeding_up_training.md
+++ b/docs/source/speeding_up_training.md
@@ -14,13 +14,7 @@
 To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm)
 
 To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
 
 ```bash
-pip install vllm
-```
-
-or
-
-```bash
-pip install "trl[vllm]"
+pip install trl[vllm]
 ```
diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md
index 0cf92df944e..9240aed62ce 100644
--- a/docs/source/vllm_integration.md
+++ b/docs/source/vllm_integration.md
@@ -2,6 +2,12 @@
 
 This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
 
+<Tip warning={true}>
+
+TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. Please ensure you have one of these versions installed to avoid compatibility issues.
+
+</Tip>
+
 ## 🚀 How can I use vLLM with TRL to speed up training?
 
 💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
diff --git a/examples/scripts/evals/judge_tldr.py b/examples/scripts/evals/judge_tldr.py
index e803f335be8..286dfa1576f 100644
--- a/examples/scripts/evals/judge_tldr.py
+++ b/examples/scripts/evals/judge_tldr.py
@@ -14,8 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "trl",
-#     "vllm",
+#     "trl[vllm]",
 # ]
 # ///
 
diff --git a/examples/scripts/rloo.py b/examples/scripts/rloo.py
index e9fb222f63b..bc599f7b9c0 100644
--- a/examples/scripts/rloo.py
+++ b/examples/scripts/rloo.py
@@ -14,12 +14,11 @@
 
 # /// script
 # dependencies = [
-#     "trl",
+#     "trl[vllm]",
 #     "peft",
 #     "math-verify",
 #     "latex2sympy2_extended",
 #     "trackio",
-#     "vllm",
 #     "kernels",
 # ]
 # ///
diff --git a/setup.cfg b/setup.cfg
index fa431a70501..67f304f4088 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,7 +64,7 @@ test =
 vllm =
     # vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added:
     # see https://github.com/vllm-project/vllm/pull/13164
-    vllm>=0.10.0; python_version < "3.13"
+    vllm>=0.10.0,<=0.10.2; python_version < "3.13"
     fastapi; python_version < "3.13"
     pydantic; python_version < "3.13"
     requests; python_version < "3.13"
diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 177c5815ba5..0932697d6ee 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -114,7 +114,7 @@ def __init__(
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
-            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
+            raise ImportError("vLLM is not installed. Please install it with `pip install trl[vllm]`.")
 
         self.session = requests.Session()
 
diff --git a/trl/import_utils.py b/trl/import_utils.py
index e495a845dae..0f15a17222c 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -14,6 +14,7 @@
 
 import importlib
 import os
+import warnings
 from itertools import chain
 from types import ModuleType
 from typing import Any
@@ -35,7 +36,7 @@
 _requests_available = _is_package_available("requests")
 _unsloth_available = _is_package_available("unsloth")
 _uvicorn_available = _is_package_available("uvicorn")
-_vllm_available = _is_package_available("vllm")
+_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
 _vllm_ascend_available = _is_package_available("vllm_ascend")
 _weave_available = _is_package_available("weave")
 
@@ -81,6 +82,15 @@ def is_uvicorn_available() -> bool:
 
 
 def is_vllm_available() -> bool:
+    if _vllm_available and (
+        version.parse(_vllm_version) < version.parse("0.10.0")
+        or version.parse(_vllm_version) > version.parse("0.10.2")
+    ):
+        warnings.warn(
+            "TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. You have version "
+            f"{_vllm_version} installed. We recommend installing one of these versions to avoid compatibility issues.",
+            UserWarning,
+        )
     return _vllm_available
 
 
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 69825102b27..2f220a229e3 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -479,7 +479,7 @@ def __init__(
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install vllm` to use it."
+                    "`pip install trl[vllm]` to use it."
                 )
 
             if self.vllm_mode == "server":
diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py
index ae0eaace284..67dfa3b25f8 100644
--- a/trl/trainer/online_dpo_config.py
+++ b/trl/trainer/online_dpo_config.py
@@ -293,7 +293,7 @@ class may differ from those in [`~transformers.TrainingArguments`].
         default=False,
         metadata={
             "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed "
-            "(`pip install vllm`)."
+            "(`pip install trl[vllm]`)."
         },
     )
     vllm_model_impl: str = field(
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 7c9466a3dd9..7f3d32a26fa 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -483,7 +483,7 @@ def __init__(
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install vllm` to use it."
+                    "`pip install trl[vllm]` to use it."
                 )
 
             if self.vllm_mode == "server":
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 13ac5e34a3d..7e799f6952c 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -551,7 +551,7 @@ def decode(example, tokenizer):
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install vllm` to use it."
+                    "`pip install trl[vllm]` to use it."
                )
 
             if self.vllm_mode == "server":