Draft
44 commits
5293d07
Introduce TP in coloc mode
toslali-ibm Apr 1, 2025
7e6245e
Introduce sleep mode in colocated vllms
toslali-ibm Apr 1, 2025
040b6e4
Fix process index bug
toslali-ibm Apr 1, 2025
3aa300a
Ignore prefix cache for now
toslali-ibm Apr 1, 2025
4e90aaa
Fix sleep issues
toslali-ibm Apr 1, 2025
85d9309
Fix tp slices
toslali-ibm Apr 1, 2025
d2bccd3
Fix typo in wake up
toslali-ibm Apr 1, 2025
d876fe4
Debugging memory
toslali-ibm Apr 2, 2025
712e85b
Measure memory during model update
toslali-ibm Apr 2, 2025
5127b97
Debug memory reserved
toslali-ibm Apr 2, 2025
a8a6f69
Remove prints
toslali-ibm Apr 2, 2025
5de5f3c
Validate via experiments
toslali-ibm Apr 2, 2025
cbb5ef3
Introduce flexible TP where TP may not be equal to world size
toslali-ibm Apr 7, 2025
c6b36e7
Remove prints
toslali-ibm Apr 7, 2025
cc3023c
Don't wake up in prefix reset
toslali-ibm Apr 7, 2025
8ae9f01
TP size should divide global world size evenly
toslali-ibm Apr 7, 2025
2de090f
Add max num seq
toslali-ibm Apr 8, 2025
8151e11
Bring back sleep
toslali-ibm Apr 8, 2025
946d49d
Fix sleep bug for grad accumulations
toslali-ibm Apr 8, 2025
fe8f684
Reload model during grad accumulation
toslali-ibm Apr 8, 2025
c3509de
Switch to sleep level 1
toslali-ibm Apr 8, 2025
aca6242
Sleep level 1 during acc steps and level 2 otherwise
toslali-ibm Apr 8, 2025
110cbce
Fix config definition
toslali-ibm Apr 9, 2025
9e5128a
Debug generations
toslali-ibm Apr 9, 2025
675a1ed
Revert to sleep level 1 - as level 2 generates randomly
toslali-ibm Apr 9, 2025
91d7e72
Conduct 72b experiment
toslali-ibm Apr 9, 2025
65666eb
Remove prints
toslali-ibm Apr 9, 2025
710da69
Make sleep optional
toslali-ibm Apr 16, 2025
a426f59
Incorporate feedback
toslali-ibm Apr 16, 2025
0db5719
Merge branch 'coloc' into tpcoloc
toslali-ibm Apr 18, 2025
a1dd8e4
Incorporate Fabian's comments
toslali-ibm Apr 18, 2025
2f95c00
Revert to sleep 2 and reload model during grad accumulation
toslali-ibm Apr 21, 2025
9c15044
Include accelerator in vllm client to access deepspeed
toslali-ibm Apr 21, 2025
3151828
Import deepspeed available
toslali-ibm Apr 21, 2025
f44b0fe
Set seed in llm init
toslali-ibm Apr 21, 2025
0bb8102
Parametrize sleep level 2 and compare
toslali-ibm Apr 21, 2025
589ffc9
Debug level 1 and level 2
toslali-ibm Apr 21, 2025
2eb1855
Fix grad accumulation for sleep 2
toslali-ibm Apr 21, 2025
82d8d94
Fix grad accumulation for sleep 2
toslali-ibm Apr 21, 2025
e4fafee
Fix grad accumulation for sleep 2
toslali-ibm Apr 21, 2025
6191b89
Revert tpcoloc branch to commit a1dd8e4 without rewriting history
toslali-ibm Apr 21, 2025
717c8b4
Revert to sleep 2 after grad acc optimization fixing the model load e…
toslali-ibm Apr 21, 2025
4badf5b
Comparison of sleep levels
toslali-ibm Apr 22, 2025
a0c677a
Add max_num_batched_tokens needed for v1 profiling
toslali-ibm May 10, 2025
107 changes: 92 additions & 15 deletions trl/extras/vllm_client.py
@@ -14,6 +14,7 @@

import atexit
import logging
import os
import time
from typing import Optional

@@ -37,8 +38,7 @@
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator

from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather_object

from accelerate.utils import broadcast_object_list, gather_object, set_seed

logger = logging.getLogger(__name__)

@@ -351,22 +351,47 @@ def close_communicator(self):

class VLLMColocationClient:
"""
A client class to interact with vLLM processes colocated with the training process.
A client class for interacting with vLLM models colocated with the training process.

This client bypasses remote communication and directly interacts with the in-process vLLM engine.
It supports weight updates and text generation functionalities similar to `VLLMClient`, but is optimized
for scenarios where vLLM is running in the same process or node as training.
This client eliminates remote communication overhead by directly interfacing with the in-process vLLM engine.
It supports weight updates and text generation, and is optimized for tensor-parallel setups where multiple
ranks share a single vLLM engine per node or process group.

Args:
args (`GRPOConfig`): Configuration object containing vLLM parameters.
model (`transformers.PreTrainedModel`): The model being used.
vllm_device (`torch.device` or `str`): Device on which the model is loaded (e.g., "cuda:0").
args (GRPOConfig): Configuration object with vLLM-specific parameters.
model (transformers.PreTrainedModel): The model used for generation and weight updates.
device (str): Device where the model is loaded.
num_processes (int): Total number of distributed processes (world size).
process_index (int): Index of the current process in the distributed setup.
"""

def __init__(self, args: GRPOConfig, model, vllm_device):
self.args: GRPOConfig = args
def __init__(self, args: GRPOConfig, model, device, num_processes, process_index):
self.args = args
self.model = model
self.vllm_device = vllm_device
self.vllm_device = device
self.world_size = num_processes
self.process_index = process_index
self._is_sleeping = False
set_seed(42)

# Ensure TP value is valid (at least 1)
assert self.args.vllm_colocation >= 1, "vllm_colocation must be greater than 0"

# Make sure TP group size evenly divides the world size
# This ensures each group has the same number of ranks
assert self.world_size % self.args.vllm_colocation == 0, (
f"TP size of vllm_colocation ({self.args.vllm_colocation}) must divide world size "
f"({self.world_size}) evenly."
)

if self.args.vllm_colocation > 1: # if model is sharded, create subgroups
# Create subgroups of ranks for TP, each group with `vllm_colocation` ranks.
# For example, if world_size=8 and vllm_colocation=2 → groups: [0,1], [2,3], [4,5], [6,7]
self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
[
list(range(i*self.args.vllm_colocation, (i+1) * self.args.vllm_colocation))
for i in range(self.world_size // self.args.vllm_colocation)
]
)

self.llm = LLM(
model=self.model.name_or_path,
@@ -375,8 +400,39 @@ def __init__(self, args: GRPOConfig, model, vllm_device):
dtype=self.args.vllm_dtype,
enable_prefix_caching=self.args.vllm_enable_prefix_caching,
max_model_len=self.args.vllm_max_model_len,
max_num_batched_tokens=self.args.vllm_max_model_len,
tensor_parallel_size=args.vllm_colocation,
distributed_executor_backend="external_launcher",
enable_sleep_mode=self.args.vllm_sleep_enabled,
max_num_seqs=self.args.per_device_train_batch_size * self.args.vllm_colocation,
seed=int(os.getenv("RANK", "0")) // self.args.vllm_colocation, # feed identical seed for tp groups
)

def maybe_wake_up_vllm(self):
"""
Wakes up the vLLM engine if it is currently in sleep mode.

This is useful before any generation or weight update calls to ensure the model is active.
It also calls `torch.cuda.empty_cache()` to free unused memory, helping avoid OOM errors.
"""
torch.cuda.empty_cache()
if self.args.vllm_sleep_enabled and self._is_sleeping:
self.llm.wake_up()
self._is_sleeping = False

def maybe_sleep_vllm(self):
"""
Puts the vLLM engine into sleep mode after generation is complete.

This helps conserve memory by offloading cached resources.
The sleep only happens if `vllm_sleep_enabled` is set to True in the config.
"""
if self.args.vllm_sleep_enabled:
if self.args.vllm_sleep_level1:
self.llm.sleep(level=1)
else:
self.llm.sleep(level=2)
self._is_sleeping = True
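# A minimal sketch of how these two helpers are meant to bracket engine use; `client`
# stands for an already-constructed VLLMColocationClient and is an assumption of this
# sketch, not something added by the diff.
client.maybe_wake_up_vllm()   # restore the engine before generating or pushing weights
# ... generation / weight update happens here ...
client.maybe_sleep_vllm()     # afterwards, sleep at the configured level (1 or 2)
client.maybe_wake_up_vllm()   # safe to call again later: a no-op unless _is_sleeping is set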

def update_named_param(self, name: str, weights: torch.Tensor):
"""
@@ -388,6 +444,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
weights (`torch.Tensor`):
Tensor containing the updated weights.
"""
self.maybe_wake_up_vllm()
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name,weights)])
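# A hedged sketch of how a training loop might stream updated weights into the colocated
# engine via this method; iterating over `model.named_parameters()` is an illustrative
# assumption, not something this diff prescribes.
for name, param in model.named_parameters():
    client.update_named_param(name, param.data)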

@@ -430,14 +487,24 @@ def generate(
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
self.maybe_wake_up_vllm()

# Guided decoding, if enabled
if guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex)
else:
guided_decoding = None

if self.args.vllm_colocation > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.args.vllm_colocation)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
prompts = [p for sublist in gathered_prompts for p in sublist]

sampling_params = SamplingParams(
n=1, # vLLM on each GPU generates only 1 in vllm_colocation mode
n=1, # vLLM on each device generates only 1 in vllm_colocation mode
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
@@ -450,7 +517,17 @@
all_outputs = self.llm.generate(
prompts, sampling_params=sampling_params, use_tqdm=False
)

completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

if self.args.vllm_colocation > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]

self.maybe_sleep_vllm()
return completion_ids
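# Standalone illustration of the gather-then-slice round trip in `generate`, using
# made-up sizes: vllm_colocation = 2 and three prompts per rank.
orig_size = 3                                                   # prompts this rank owned before the gather
gathered_prompts = [["p0", "p1", "p2"], ["p3", "p4", "p5"]]     # what all_gather_object would collect
prompts = [p for sublist in gathered_prompts for p in sublist]  # every rank now sees all six
local_rank_in_group = 1                                         # this rank's position inside its TP group
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
print(prompts[tp_slice])                                        # ['p3', 'p4', 'p5'] -> only this rank's share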

def reset_prefix_cache(self):
@@ -477,7 +554,7 @@ def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNo
accelerator (`Accelerator`): Hugging Face `Accelerator` object that helps with multi-GPU training.
"""
if args.vllm_colocation:
return VLLMColocationClient(args, model, accelerator.device)
return VLLMColocationClient(args, model, accelerator.device, accelerator.num_processes, accelerator.process_index)
elif accelerator.is_main_process:
return VLLMClient(
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout,
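The tensor-parallel bookkeeping in `__init__` above can be reproduced standalone. The sketch below reuses the world size of 8 and group size of 2 from the comment in the diff; it only illustrates the rank enumeration handed to `torch.distributed.new_subgroups_by_enumeration` and the per-group seed derived from the `RANK` environment variable.

world_size, vllm_colocation = 8, 2

# Rank groups, one vLLM TP group per sublist
groups = [
    list(range(i * vllm_colocation, (i + 1) * vllm_colocation))
    for i in range(world_size // vllm_colocation)
]
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Seed passed to the LLM constructor: ranks in the same TP group share a seed
seeds = [rank // vllm_colocation for rank in range(world_size)]
print(seeds)   # [0, 0, 1, 1, 2, 2, 3, 3]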
41 changes: 33 additions & 8 deletions trl/trainer/grpo_config.py
@@ -94,10 +94,14 @@ class GRPOConfig(TrainingArguments):
timeout, a `ConnectionError` is raised.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
vllm_colocation (`bool`, *optional*, defaults to `False`):
Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be
initialized in **all processes**, each assigned to its respective device. This allows multi-GPU
or multi-node execution with vLLM's external launcher, enabling improved large-scale inference.
vllm_colocation (`int` or `None`, *optional*, defaults to `None`):
Controls colocated vLLM execution and tensor parallelism via the `external_launcher` backend.
- Set to `None` to disable colocated vLLM entirely.
- Set to `1` to enable colocated vLLM on each GPU with no tensor parallelism.
- Set to a value >1 to enable colocated vLLM with tensor parallelism across multiple GPUs.
vllm_sleep_enabled (`bool`, *optional*, defaults to `False`):
Indicates whether to enable the sleep operation for vLLM during training.
If set to `True`, vLLM will remain in sleep mode throughout the training stage.

> Parameters that control the training

@@ -293,12 +297,33 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_colocation: Optional[bool] = field(
vllm_colocation: Optional[int] = field(
default=None,
metadata={
"help": (
"Controls colocated vLLM execution and tensor parallelism using the `external_launcher` backend. "
"Set to `None` to disable colocated vLLM. "
"Set to `1` to enable colocated vLLM on each device (no tensor parallelism). "
"Set to a value >1 to enable colocated vLLM with tensor parallelism across multiple devices."
)
},
)
vllm_sleep_enabled: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be "
"initialized in all processes, each assigned to its respective device. This enables optimized "
"multi-GPU inference."
"help": (
"Enables sleep mode for colocated vLLM during training. "
"Set to `True` to keep vLLM in sleep state during training steps, helping reduce memory usage. "
"Set to `False` to disable this behavior."
)
},
)
vllm_sleep_level1: Optional[bool] = field(
default=False,
metadata={
"help": (
"Sleep level 1 enabled - otherwise sleep level 2 default"
)
},
)

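Taken together, the new `GRPOConfig` fields would be set roughly as in the sketch below. The output directory and the choice of two ranks per TP group are placeholders, and the sketch assumes this patch is applied.

from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-output",   # placeholder
    vllm_colocation=2,          # colocated vLLM with TP groups of 2 ranks; None disables colocation
    vllm_sleep_enabled=True,    # put the colocated engine to sleep between generation steps
    vllm_sleep_level1=True,     # sleep level 1; leave False to use the default level 2
)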