diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 90a33ae87b9..0b5e364eaa2 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -14,6 +14,7 @@
 import atexit
 import logging
+import os
 import time
 from typing import Optional
@@ -37,8 +38,7 @@
         from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
 
 from accelerate import Accelerator
-from accelerate.utils import broadcast_object_list, gather_object
-
+from accelerate.utils import broadcast_object_list, gather_object, set_seed
 
 logger = logging.getLogger(__name__)
@@ -351,22 +351,47 @@ def close_communicator(self):
 class VLLMColocationClient:
     """
-    A client class to interact with vLLM processes colocated with the training process.
+    A client class for interacting with vLLM models colocated with the training process.
 
-    This client bypasses remote communication and directly interacts with the in-process vLLM engine.
-    It supports weight updates and text generation functionalities similar to `VLLMClient`, but is optimized
-    for scenarios where vLLM is running in the same process or node as training.
+    This client eliminates remote communication overhead by directly interfacing with the in-process vLLM engine.
+    It supports weight updates and text generation, and is optimized for tensor-parallel setups where multiple
+    ranks share a single vLLM engine per node or process group.
 
     Args:
-        args (`GRPOConfig`): Configuration object containing vLLM parameters.
-        model (`transformers.PreTrainedModel`): The model being used.
-        vllm_device (`torch.device` or `str`): Device on which the model is loaded (e.g., "cuda:0").
+        args (`GRPOConfig`): Configuration object with vLLM-specific parameters.
+        model (`transformers.PreTrainedModel`): The model used for generation and weight updates.
+        device (`str`): Device where the model is loaded.
+        num_processes (`int`): Total number of distributed processes (world size).
+        process_index (`int`): Index of the current process in the distributed setup.
     """
-
-    def __init__(self, args: GRPOConfig, model, vllm_device):
-        self.args: GRPOConfig = args
+
+    def __init__(self, args: GRPOConfig, model, device, num_processes, process_index):
+        self.args = args
         self.model = model
-        self.vllm_device = vllm_device
+        self.vllm_device = device
+        self.world_size = num_processes
+        self.process_index = process_index
+        self._is_sleeping = False
+        set_seed(42)
+
+        # Ensure the TP size is valid (at least 1)
+        assert self.args.vllm_colocation >= 1, "vllm_colocation must be greater than 0"
+
+        # Make sure the TP group size evenly divides the world size,
+        # so that every group has the same number of ranks
+        assert self.world_size % self.args.vllm_colocation == 0, (
+            f"TP size of vllm_colocation ({self.args.vllm_colocation}) must divide world size "
+            f"({self.world_size}) evenly."
+        )
+
+        if self.args.vllm_colocation > 1:  # if the model is sharded across ranks, create TP subgroups
+            # Create subgroups of ranks for TP, each group holding `vllm_colocation` consecutive ranks.
+            # For example, if world_size=8 and vllm_colocation=2 → groups: [0,1], [2,3], [4,5], [6,7]
+            self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
+                [
+                    list(range(i * self.args.vllm_colocation, (i + 1) * self.args.vllm_colocation))
+                    for i in range(self.world_size // self.args.vllm_colocation)
+                ]
+            )
 
         self.llm = LLM(
             model=self.model.name_or_path,
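The grouping arithmetic behind `new_subgroups_by_enumeration` above can be checked in isolation. A minimal standalone sketch with hypothetical values (`world_size=8`, `vllm_colocation=2`); no distributed initialization is needed to see the resulting rank groups:

```python
# Sketch of the rank-grouping arithmetic used in __init__ above.
# world_size and vllm_colocation stand in for the values the client
# reads from the accelerator and the GRPOConfig.
world_size = 8       # accelerator.num_processes
vllm_colocation = 2  # TP size of each colocated vLLM engine

assert world_size % vllm_colocation == 0  # same check as in __init__
rank_groups = [
    list(range(i * vllm_colocation, (i + 1) * vllm_colocation))
    for i in range(world_size // vllm_colocation)
]
print(rank_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
# torch.distributed.new_subgroups_by_enumeration receives this list and
# returns the subgroup containing the current rank.
```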
@@ -375,8 +400,39 @@ def __init__(self, args: GRPOConfig, model, vllm_device):
             dtype=self.args.vllm_dtype,
             enable_prefix_caching=self.args.vllm_enable_prefix_caching,
             max_model_len=self.args.vllm_max_model_len,
+            max_num_batched_tokens=self.args.vllm_max_model_len,
+            tensor_parallel_size=args.vllm_colocation,
             distributed_executor_backend="external_launcher",
+            enable_sleep_mode=self.args.vllm_sleep_enabled,
+            max_num_seqs=self.args.per_device_train_batch_size * self.args.vllm_colocation,
+            seed=int(os.getenv("RANK", "0")) // self.args.vllm_colocation,  # feed an identical seed to all ranks in a TP group
         )
+
+    def maybe_wake_up_vllm(self):
+        """
+        Wakes up the vLLM engine if it is currently in sleep mode.
+
+        Call this before any generation or weight-update call to ensure the engine is active.
+        It also calls `torch.cuda.empty_cache()` to free unused memory, helping avoid OOM errors.
+        """
+        torch.cuda.empty_cache()
+        if self.args.vllm_sleep_enabled and self._is_sleeping:
+            self.llm.wake_up()
+            self._is_sleeping = False
+
+    def maybe_sleep_vllm(self):
+        """
+        Puts the vLLM engine into sleep mode after generation is complete.
+
+        This conserves memory by offloading cached resources. The engine only sleeps if
+        `vllm_sleep_enabled` is set to `True` in the config.
+        """
+        if self.args.vllm_sleep_enabled:
+            if self.args.vllm_sleep_level1:
+                self.llm.sleep(level=1)
+            else:
+                self.llm.sleep(level=2)
+            self._is_sleeping = True
 
     def update_named_param(self, name: str, weights: torch.Tensor):
         """
@@ -388,6 +444,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
             weights (`torch.Tensor`):
                 Tensor containing the updated weights.
         """
+        self.maybe_wake_up_vllm()
         llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
         llm_model.load_weights([(name,weights)])
@@ -430,14 +487,24 @@ def generate(
             `list[list[int]]`: List of lists of token IDs representing the model-generated completions for each
                 prompt.
         """
+        self.maybe_wake_up_vllm()
+
         # Guided decoding, if enabled
         if guided_decoding_regex is not None:
             guided_decoding = GuidedDecodingParams(backend="outlines", regex=guided_decoding_regex)
         else:
             guided_decoding = None
 
+        if self.args.vllm_colocation > 1:
+            # Gather prompts from all ranks in the TP group and flatten.
+            # Each rank starts with its own prompts; after gathering, every rank sees the full group set.
+            orig_size = len(prompts)
+            gathered_prompts = [None for _ in range(self.args.vllm_colocation)]
+            torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
+            prompts = [p for sublist in gathered_prompts for p in sublist]
+
         sampling_params = SamplingParams(
-            n=1,  # vLLM on each GPU generates only 1 in vllm_colocation mode
+            n=1,  # vLLM on each device generates only 1 completion per prompt in vllm_colocation mode
             repetition_penalty=repetition_penalty,
             temperature=temperature,
             top_p=top_p,
@@ -450,7 +517,17 @@
         all_outputs = self.llm.generate(
             prompts, sampling_params=sampling_params, use_tqdm=False
         )
+        completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
+
+        if self.args.vllm_colocation > 1:
+            # Slice completions for this rank within its TP group.
+            # Each rank generated outputs for the whole group; keep only this rank's share.
+            local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
+            tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+            completion_ids = completion_ids[tp_slice]
+
+        self.maybe_sleep_vllm()
         return completion_ids
 
     def reset_prefix_cache(self):
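The gather-then-slice round-trip in `generate` above keeps each rank's prompts and completions aligned. A minimal sketch, with plain lists standing in for `torch.distributed.all_gather_object` and `llm.generate`, and hypothetical ranks and prompts:

```python
# Sketch of the prompt-gather / completion-slice round-trip in generate().
vllm_colocation = 2  # TP group size
prompts_per_rank = {0: ["p0", "p1"], 1: ["p2", "p3"]}  # each rank's local prompts

for rank in range(vllm_colocation):
    prompts = prompts_per_rank[rank]
    orig_size = len(prompts)

    # all_gather_object: every rank ends up with the full, rank-ordered set
    gathered_prompts = [prompts_per_rank[r] for r in range(vllm_colocation)]
    flat_prompts = [p for sublist in gathered_prompts for p in sublist]

    # Every rank generates for all four prompts (stands in for llm.generate)
    completion_ids = [f"completion({p})" for p in flat_prompts]

    # Keep only this rank's share, in the order it contributed its prompts
    tp_slice = slice(rank * orig_size, (rank + 1) * orig_size)
    print(rank, completion_ids[tp_slice])
# 0 ['completion(p0)', 'completion(p1)']
# 1 ['completion(p2)', 'completion(p3)']
```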
@@ -477,7 +554,7 @@ def get_vllm_client(args: GRPOConfig, model, accelerator: Accelerator) -> VLLMNo
         accelerator (`Accelerator`): Hugging Face `Accelerator` object that helps with multi-GPU training.
     """
     if args.vllm_colocation:
-        return VLLMColocationClient(args, model, accelerator.device)
+        return VLLMColocationClient(args, model, accelerator.device, accelerator.num_processes, accelerator.process_index)
     elif accelerator.is_main_process:
         return VLLMClient(
             args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout,
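The two sleep helpers added to the client form a small state machine: wake before generation or a weight update, sleep after generation completes. A minimal sketch of that flag logic, with a stub class standing in for the real `vllm.LLM` engine:

```python
# Sketch of the sleep/wake state machine implemented by maybe_wake_up_vllm /
# maybe_sleep_vllm above. StubEngine stands in for vllm.LLM.
class StubEngine:
    def wake_up(self):
        print("engine awake")

    def sleep(self, level):
        print(f"engine asleep (level {level})")

class SleepStateDemo:
    def __init__(self, sleep_enabled=True, level1=False):
        self.sleep_enabled = sleep_enabled
        self.level1 = level1
        self._is_sleeping = False
        self.llm = StubEngine()

    def maybe_wake_up(self):
        # mirrors maybe_wake_up_vllm (minus torch.cuda.empty_cache)
        if self.sleep_enabled and self._is_sleeping:
            self.llm.wake_up()
            self._is_sleeping = False

    def maybe_sleep(self):
        # mirrors maybe_sleep_vllm: level 1 if configured, else level 2
        if self.sleep_enabled:
            self.llm.sleep(level=1 if self.level1 else 2)
            self._is_sleeping = True

demo = SleepStateDemo()
demo.maybe_wake_up()  # no-op: the engine starts awake
demo.maybe_sleep()    # -> engine asleep (level 2)
demo.maybe_wake_up()  # -> engine awake, as called before generate / weight updates
```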
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index f31f392846c..4d9751e2878 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -94,10 +94,14 @@ class GRPOConfig(TrainingArguments):
             timeout, a `ConnectionError` is raised.
         vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
             Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
-        vllm_colocation (`bool`, *optional*, defaults to `False`):
-            Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be
-            initialized in **all processes**, each assigned to its respective device. This allows multi-GPU
-            or multi-node execution with vLLM's external launcher, enabling improved large-scale inference.
+        vllm_colocation (`int` or `None`, *optional*, defaults to `None`):
+            Controls colocated vLLM execution and tensor parallelism via the `external_launcher` backend. Set
+            to `None` to disable colocated vLLM, to `1` to run a colocated vLLM engine on each GPU with no
+            tensor parallelism, or to a value greater than `1` to shard each engine across that many GPUs.
+        vllm_sleep_enabled (`bool`, *optional*, defaults to `False`):
+            Whether to enable vLLM's sleep mode during training. If set to `True`, the colocated vLLM engine
+            is put to sleep during optimization steps and woken up for generation and weight updates,
+            reducing GPU memory usage.
 
     > Parameters that control the training
@@ -293,12 +297,33 @@ class GRPOConfig(TrainingArguments):
         default=None,
         metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
     )
-    vllm_colocation: Optional[bool] = field(
+    vllm_colocation: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Controls colocated vLLM execution and tensor parallelism using the `external_launcher` backend. "
+                "Set to `None` to disable colocated vLLM. "
+                "Set to `1` to enable colocated vLLM on each device (no tensor parallelism). "
+                "Set to a value >1 to enable colocated vLLM with tensor parallelism across multiple devices."
+            )
+        },
+    )
+    vllm_sleep_enabled: Optional[bool] = field(
         default=False,
         metadata={
-            "help": "Whether to use colocated vLLM execution via external launcher. If set to `True`, vLLM will be "
-            "initialized in all processes, each assigned to its respective device. This enables optimized "
-            "multi-GPU inference."
+            "help": (
+                "Enables sleep mode for colocated vLLM during training. "
+                "Set to `True` to put vLLM to sleep during training steps, reducing memory usage. "
+                "Set to `False` to disable this behavior."
+            )
+        },
+    )
+    vllm_sleep_level1: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": (
+                "If `True`, vLLM sleeps at level 1; otherwise sleep level 2 is used by default."
+            )
         },
     )
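Putting the new fields together, a hedged configuration sketch (values are illustrative only; with one process per GPU, e.g. under `accelerate launch` or `torchrun`, 8 processes and `vllm_colocation=2` yield four colocated TP groups):

```python
from trl import GRPOConfig

# Illustrative values only: 8 training processes sharing four colocated
# vLLM engines (TP size 2), with each engine put to sleep between generations.
config = GRPOConfig(
    output_dir="grpo-output",
    per_device_train_batch_size=4,
    vllm_colocation=2,        # TP size; must evenly divide the world size
    vllm_sleep_enabled=True,  # sleep the engine during optimization steps
    vllm_sleep_level1=False,  # False -> sleep level 2 (the default)
)
```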