
Conversation

@colinzhaoxp commented Nov 17, 2025

Add compute_metrics parameter for GRPOTrainer

This is my first open-source PR contribution, and I would greatly appreciate any feedback on areas for improvement. Please don't hesitate to suggest changes - I'm eager to learn and make this contribution as good as possible!

What does this PR do?

This PR adds a compute_metrics parameter to GRPOTrainer; compute_metrics is already supported by Trainer. With it, users can compute accuracy or other downstream eval metrics over the evaluation dataset.

Fixes related issues

#3729
#2959

Changes Made

Added compute_metrics parameter to GRPOTrainer
File: trl/trainer/grpo_trainer.py

Added a new optional parameter after num_generations:

from transformers.trainer_utils import seed_worker, EvalPrediction

class GRPOTrainer(BaseTrainer):
    """
    ...
    compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
        The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
        return a dictionary mapping metric names to metric values. *Note* When passing TrainingArgs with
        `batch_eval_metrics` set to `True`, your compute_metrics function must take a boolean `compute_result`
        argument. This will be triggered after the last eval batch to signal that the function needs to
        calculate and return the global summary statistics rather than accumulating the batch-level statistics.
    ...
    """

    def __init__(
        self,
        ...
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        ...
    ):
        ...
        super().__init__(
            ...
            compute_metrics=compute_metrics,
            ...
        )
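
For reference, a compute_metrics callback written against the standard Trainer contract could look like the sketch below. This is only an illustration of the interface the new parameter forwards to the base Trainer; the accuracy metric and the argmax step assume a classification-style evaluation and are not part of this PR.

import numpy as np
from transformers.trainer_utils import EvalPrediction

def accuracy_metrics(eval_pred: EvalPrediction) -> dict:
    # eval_pred bundles the gathered predictions and label ids from the eval loop.
    predictions, labels = eval_pred.predictions, eval_pred.label_ids
    # Assumption: predictions are per-class scores, so reduce them with argmax.
    preds = np.argmax(predictions, axis=-1)
    return {"accuracy": float((preds == labels).mean())}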

Example Usage

def my_eval_function(eval_predict):
    pass

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=my_eval_function,
)
# trainer.train() # evaluation during training
trainer.evaluate() # or directly evaluate your model.
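
If batch_eval_metrics is set to True in the training arguments, the callback is additionally expected to take a compute_result flag, as noted in the docstring above. A minimal sketch of that shape (the running-accuracy bookkeeping is illustrative, not part of this PR):

import numpy as np
from transformers.trainer_utils import EvalPrediction

_correct, _total = 0, 0

def batched_accuracy(eval_pred: EvalPrediction, compute_result: bool) -> dict:
    # Accumulate per-batch statistics; return the global summary on the last batch.
    global _correct, _total
    preds = np.argmax(eval_pred.predictions, axis=-1)
    _correct += int((preds == eval_pred.label_ids).sum())
    _total += preds.size
    if compute_result:
        accuracy = _correct / max(_total, 1)
        _correct, _total = 0, 0  # reset for the next evaluation run
        return {"accuracy": accuracy}
    return {}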

More examples are available in this blog

Benefits

  • Flexible: Users can choose their own function to evaluate their model during training.

Who can review?

Any community member is welcome to provide feedback.
As this is my first open-source contribution, I'm excited to learn - please don't hesitate to suggest any enhancements!

@qgallouedec
Member

thanks! I don't think it would work just like this, though, because in Trainer, compute_metrics is called only if compute_loss returns logits and labels, and in GRPO it's not clear what the labels would be. Consequently, compute_loss doesn't support return_outputs:

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")
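
To make the gating concrete, here is a toy stand-in (a simplified paraphrase, not the actual transformers source) for the check the evaluation loop performs before invoking the callback; since GRPO's compute_loss never hands back outputs, the gathered predictions and labels stay None and the branch is never taken:

def maybe_compute_metrics(compute_metrics, all_preds, all_labels):
    # Simplified paraphrase of the guard in Trainer's evaluation loop: the callback
    # only runs when both predictions and labels have been gathered.
    if compute_metrics is not None and all_preds is not None and all_labels is not None:
        return compute_metrics(all_preds, all_labels)
    return {}

# With GRPO, nothing is gathered, so the callback is never invoked:
print(maybe_compute_metrics(lambda preds, labels: {"accuracy": 1.0}, None, None))  # -> {}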

You can run this code, and see that my_eval_function is never called:

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only")

def my_eval_function(eval_predict):
    print(eval_predict)


# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c[0]["content"])) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=GRPOConfig(
        eval_steps=2,
        eval_strategy="steps",
    ),
    compute_metrics=my_eval_function,
)
trainer.train()

@colinzhaoxp
Author

Thanks for your reply!

Yes, the demo I gave above is not minimally runnable. I just wanted to explain how to add compute_metrics to GRPOTrainer. As your comment points out, to actually use this function we need compute_loss to return logits and labels, plus some other changes to make everything fit.

So should my next step be to provide a minimal runnable demo? I think adding other code, such as changing compute_loss, is not reasonable for this PR. I need your suggestions to improve it.

In any case, I will first provide a minimal runnable demo.

Thanks
