10 changes: 9 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -49,7 +49,7 @@
is_trackio_available,
is_wandb_available,
)
-from transformers.trainer_utils import seed_worker
+from transformers.trainer_utils import seed_worker, EvalLoopOutput
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available

from ..data_utils import (
@@ -190,6 +190,12 @@ class GRPOTrainer(BaseTrainer):
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
`tokenizer.eos_token` will be used as the default.
+compute_metrics (`Callable[[EvalLoopOutput], dict]`, *optional*):
+The function used to compute metrics at evaluation. Must take an [`EvalLoopOutput`] and return a
+dictionary mapping metric names to metric values. *Note*: when passing `TrainingArguments` with
+`batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
+`compute_result` argument. It is set to `True` after the last eval batch to signal that the function
+should compute and return global summary statistics rather than accumulate batch-level statistics.
reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

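For context on the new hook: below is a minimal sketch of a `compute_metrics` callback, assuming the eval loop hands it an `EvalLoopOutput` with `predictions` and `label_ids` populated (the transformers convention; exactly which fields GRPOTrainer gathers is not visible in this diff, so treat them as assumptions). The second variant shows the `compute_result` protocol required when `batch_eval_metrics=True`:

```python
# Hedged sketch: the fields read from EvalLoopOutput (`predictions`,
# `label_ids`) follow the transformers convention; whether GRPOTrainer
# populates them this way is an assumption, not confirmed by this diff.
import numpy as np
from transformers.trainer_utils import EvalLoopOutput

def compute_metrics(output: EvalLoopOutput) -> dict:
    preds = np.asarray(output.predictions)
    labels = np.asarray(output.label_ids)
    return {"accuracy": float((preds == labels).mean())}

# With `batch_eval_metrics=True`, the callback runs once per eval batch and
# must accept `compute_result`; it accumulates until the final call, which
# signals that global statistics should be returned.
_state = {"correct": 0, "total": 0}

def compute_metrics_batched(output: EvalLoopOutput, compute_result: bool) -> dict:
    preds = np.asarray(output.predictions)
    labels = np.asarray(output.label_ids)
    _state["correct"] += int((preds == labels).sum())
    _state["total"] += int(preds.size)
    if not compute_result:
        return {}
    accuracy = _state["correct"] / max(_state["total"], 1)
    _state["correct"] = _state["total"] = 0  # reset for the next evaluation
    return {"accuracy": accuracy}
```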
@@ -242,6 +248,7 @@ def __init__(
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
+compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
@@ -450,6 +457,7 @@ def __init__(
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
+compute_metrics=compute_metrics,
# In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func`
# is None. For DAPO, loss scaling instead depends on the total number of completion tokens across the
# global accumulated batch. To control scaling ourselves, we must disable Trainer's built-in scaling. The
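A side note for reviewers: the truncated comment above refers to DAPO-style token-level normalization, where the summed per-token loss is divided by the total number of completion tokens in the whole accumulated batch rather than by `gradient_accumulation_steps`. A minimal sketch of that scaling rule (names here are illustrative, not TRL internals):

```python
import torch

# Illustrative only: TRL's internal names differ; this just shows the rule.
def dapo_scaled_loss(per_token_loss: torch.Tensor,
                     completion_mask: torch.Tensor,
                     total_completion_tokens: int) -> torch.Tensor:
    # per_token_loss, completion_mask: (batch, seq_len). total_completion_tokens
    # counts tokens across the entire gradient-accumulated global batch.
    return (per_token_loss * completion_mask).sum() / total_completion_tokens
```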
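End to end, the new argument would be wired like this. A hypothetical usage sketch: the model name, dataset, and toy reward function are placeholders chosen for illustration, not part of this PR:

```python
# Hypothetical wiring of the new `compute_metrics` argument; model, dataset,
# and reward function are placeholders. `compute_metrics` is the callback
# sketched earlier in this thread.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions.
    return [-float(len(c)) for c in completions]

# Placeholder prompt dataset; any dataset with train/validation splits works.
dataset = load_dataset("trl-lib/tldr")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="grpo-output", eval_strategy="steps", eval_steps=50),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,  # new in this PR
)
trainer.train()
```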