4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -64,8 +64,6 @@
title: GRPO
- local: kto_trainer
title: KTO
- local: prm_trainer
title: PRM
- local: reward_trainer
title: Reward
- local: rloo_trainer
@@ -119,6 +117,8 @@
title: PAPO
- local: ppo_trainer
title: PPO
- local: prm_trainer
title: PRM
- local: xpo_trainer
title: XPO
- local: openenv
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -391,7 +391,6 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
@@ -401,6 +400,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
| [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) |
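
The stepwise-supervision format expected by [`experimental.prm.PRMTrainer`] pairs a prompt with a list of reasoning steps and one correctness label per step. A minimal sketch of such an example (the field names follow the stepwise-supervision format linked above; the concrete strings are only illustrative):

```python
stepwise_example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, 9.11 is the larger number.",
    ],
    # One boolean per step: True if the step is correct, False otherwise.
    "labels": [True, False],
}
```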

## Using any dataset with TRL: preprocessing and conversion
2 changes: 1 addition & 1 deletion docs/source/example_overview.md
@@ -62,7 +62,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`experimental.prm.PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. |
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -31,8 +31,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL

### Reward modeling

- [`PRMTrainer`]
- [`RewardTrainer`]
- [`experimental.prm.PRMTrainer`] 🧪

</div>
<div style="flex: 1; min-width: 0;">
9 changes: 6 additions & 3 deletions docs/source/prm_trainer.md
@@ -2,6 +2,9 @@

[![model badge](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl)

> [!TIP]
> PRMTrainer has been moved to `trl.experimental.prm.PRMTrainer`. The `trl.trainer` version is deprecated and will be removed in TRL 0.29.0. Please update your imports to use `trl.experimental.prm.PRMTrainer` instead. See [issue #4467](https://github.com/huggingface/trl/issues/4467) for more information.

> [!WARNING]
> PRM Trainer is an experimental API which is subject to change at any time.
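
In practice the migration is just an import-path change; a minimal before/after sketch:

```python
# Deprecated import path (scheduled for removal in TRL 0.29.0):
# from trl import PRMConfig, PRMTrainer

# New import path:
from trl.experimental.prm import PRMConfig, PRMTrainer
```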

@@ -31,7 +34,7 @@ Below is the script to train the model:
```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from trl.experimental.prm import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
@@ -112,11 +115,11 @@ accelerate launch examples/scripts/prm.py \

## PRMTrainer

[[autodoc]] PRMTrainer
[[autodoc]] experimental.prm.PRMTrainer
- train
- save_model
- push_to_hub

## PRMConfig

[[autodoc]] PRMConfig
[[autodoc]] experimental.prm.PRMConfig
3 changes: 1 addition & 2 deletions examples/scripts/prm.py
@@ -58,13 +58,12 @@

from trl import (
ModelConfig,
PRMConfig,
PRMTrainer,
ScriptArguments,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.experimental.prm import PRMConfig, PRMTrainer


logger = logging.get_logger(__name__)
@@ -20,9 +20,9 @@
from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase
from transformers.utils import is_peft_available

from trl import PRMConfig, PRMTrainer
from trl.experimental.prm import PRMConfig, PRMTrainer

from .testing_utils import TrlTestCase, require_peft
from ..testing_utils import TrlTestCase, require_peft


if is_peft_available():
19 changes: 19 additions & 0 deletions trl/experimental/prm/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .prm_config import PRMConfig
from .prm_trainer import PRMTrainer


__all__ = ["PRMConfig", "PRMTrainer"]
112 changes: 112 additions & 0 deletions trl/experimental/prm/prm_config.py
@@ -0,0 +1,112 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class PRMConfig(TrainingArguments):
r"""
Configuration class for the [`experimental.prm.PRMTrainer`].

This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
differ from those in [`~transformers.TrainingArguments`].

Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) used for truncation.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt used for truncation.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
step_separator (`str`, *optional*, defaults to `"\n"`):
Separator used to separate each step of the reasoning process.
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
Whether to train only on the last step.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
"""

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-5,
metadata={"help": "The initial learning rate for AdamW."},
)
logging_steps: float = field(
default=10,
metadata={
"help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
"will be interpreted as ratio of total training steps."
},
)
gradient_checkpointing: bool = field(
default=True,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
"architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
"`fp16` is not set."
},
)

max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
)
max_prompt_length: int | None = field(
default=512,
metadata={"help": "Maximum length of the prompt used for truncation."},
)
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion used for truncation. The completion is the concatenation of the "
"steps."
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
)
step_separator: str = field(
default="\n",
metadata={"help": "Separator used to separate each step of the reasoning process."},
)
train_on_last_step_only: bool = field(
default=False,
metadata={"help": "Whether to train only on the last step."},
)
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

def __post_init__(self):
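# If the precision flag is not set explicitly, default to bf16 unless fp16 was requested.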
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

super().__post_init__()
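
Taken together, a short end-to-end sketch of how this configuration is consumed by the trainer. The `Qwen/Qwen2-0.5B` checkpoint and `num_labels=2` head come from the `train_prm.py` snippet in `docs/source/prm_trainer.md` above; the `trl-lib/math_shepherd` dataset name, the output directory, and the parameter values are assumptions for illustration, and keyword names follow the standard TRL trainer signature:

```python
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer

from trl.experimental.prm import PRMConfig, PRMTrainer

# Token-classification head with two labels: step is correct / incorrect.
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train")

# PRMConfig overrides a few TrainingArguments defaults (learning_rate=1e-5,
# gradient_checkpointing=True, bf16 unless fp16 is set); only PRM-specific
# options are spelled out here, using their default values.
training_args = PRMConfig(
    output_dir="Qwen2-0.5B-PRM",
    max_length=1024,
    step_separator="\n",
    train_on_last_step_only=False,
)
trainer = PRMTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```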