
@mingxuetian commented Nov 5, 2025

Add num_generations_eval parameter for efficient evaluation

This is my first open-source PR contribution, and I would greatly appreciate any feedback on areas for improvement. Please don't hesitate to suggest changes; I'm eager to learn and make this contribution as good as possible!

What does this PR do?

This PR adds support for using a different number of generations during evaluation compared to training in GRPOTrainer. This allows users to save computation time during evaluation while maintaining training quality.

Fixes

Fixes #3539 and #3566

Motivation

During training, multiple generations per prompt are often needed for better exploration and diversity. However, during evaluation, fewer generations are typically sufficient to assess model performance. This feature enables more efficient evaluation without compromising training effectiveness.

For example, users can train with 16 generations per prompt but evaluate with only 2 generations, reducing evaluation time by 8x.

Changes Made

1. Added num_generations_eval parameter to GRPOConfig

File: trl/trainer/grpo_config.py

Added a new optional parameter after num_generations:

num_generations_eval: int | None = field(
    default=None,
    metadata={
        "help": "Number of generations to sample during evaluation. If `None`, uses the value of "
        "`num_generations`. This allows using fewer generations during evaluation to save computation. "
        "Maintains backward compatibility with previous configuration files."
    },
)
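
As a quick illustration of the fallback rule, the effective evaluation group size resolves to num_generations whenever num_generations_eval is left at its default of None. This is a minimal sketch in plain Python; effective_eval_generations is a hypothetical helper written for illustration, not part of TRL:

def effective_eval_generations(num_generations, num_generations_eval=None):
    # Hypothetical helper mirroring the trainer's fallback: use the eval-specific
    # value when provided, otherwise fall back to the training value.
    return num_generations_eval if num_generations_eval is not None else num_generations

assert effective_eval_generations(8) == 8       # default: identical to current behavior
assert effective_eval_generations(16, 2) == 2   # opt-in: cheaper evaluation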

2. Modified GRPOTrainer.__init__ to store the parameter

File: trl/trainer/grpo_trainer.py

Added line 383 to store the new parameter:

self.num_generations = args.num_generations  # = G in the GRPO paper
self.num_generations_eval = args.num_generations_eval  # NEW LINE
self.chat_template_kwargs = args.chat_template_kwargs or {}

3. Updated _get_eval_sampler method

File: trl/trainer/grpo_trainer.py

Modified the eval sampler to use num_generations_eval when available:

def _get_eval_sampler(self, eval_dataset) -> Sampler:
    # See _get_train_sampler for an explanation of the sampler.
    # If None, use num_generations for backward compatibility with previous config files
    num_gens = self.num_generations_eval or self.num_generations
    return RepeatSampler(
        data_source=eval_dataset,
        mini_repeat_count=num_gens,
        seed=self.args.seed,
    )
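
For intuition, a sampler with mini_repeat_count=num_gens should yield each eval example index num_gens times back-to-back, so completions for the same prompt stay grouped. A rough plain-Python illustration of that expected ordering (not TRL's actual RepeatSampler implementation):

num_gens = 2  # e.g. num_generations_eval
eval_indices = [0, 1, 2]  # toy dataset indices
expanded = [i for i in eval_indices for _ in range(num_gens)]
print(expanded)  # [0, 0, 1, 1, 2, 2] -- each prompt is sampled num_gens times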

4. Updated vLLM server mode generation logic

File: trl/trainer/grpo_trainer.py (lines 1166-1173)

Modified to dynamically select the correct number of generations based on mode:

# Determine num_generations based on mode
mode = "train" if self.model.training else "eval"
num_gens = (
    self.num_generations_eval
    if mode == "eval" and self.num_generations_eval is not None
    else self.num_generations
)
ordered_set_of_prompts = all_prompts[::num_gens]
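
The slicing assumes all_prompts arrives with each prompt already repeated num_gens times in order, so taking every num_gens-th element recovers one copy of each prompt. A toy sketch with made-up prompt strings:

num_gens = 2
all_prompts = ["p0", "p0", "p1", "p1", "p2", "p2"]  # each prompt repeated num_gens times
ordered_set_of_prompts = all_prompts[::num_gens]
print(ordered_set_of_prompts)  # ['p0', 'p1', 'p2']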

5. Updated prompt repetition logic in server mode

File: trl/trainer/grpo_trainer.py (lines 1223-1231)

Modified to repeat prompts the correct number of times:

# Determine repeat count based on mode
mode = "train" if self.model.training else "eval"
num_gens = (
    self.num_generations_eval
    if mode == "eval" and self.num_generations_eval is not None
    else self.num_generations
)
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_gens times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_gens)]
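
This is the inverse of the slicing above: after generation, each unique prompt's token ids are repeated num_gens times so that prompts and completions line up one-to-one. A toy sketch with made-up token ids:

num_gens = 2
all_prompt_ids = [[101, 7], [101, 8], [101, 9]]  # one entry per unique prompt
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_gens)]
print(len(all_prompt_ids))  # 6 == num_unique_prompts * num_gens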

6. Updated reward computation logic

File: trl/trainer/grpo_trainer.py (lines 1616-1621)

Modified to handle different generation counts for train/eval modes:

# Determine num_generations based on mode before computing grouped-wise rewards
# If num_generations_eval is None, use num_generations for backward compatibility with previous config files
mode = "train" if self.model.training else "eval"
num_gens = (self.num_generations_eval or self.num_generations) if mode == "eval" else self.num_generations
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, num_gens).mean(dim=1)
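
A worked example of the grouped-wise mean with an evaluation group size of 2 (reward values are made up; only the reshaping logic matters):

import torch

num_gens = 2  # num_generations_eval during evaluation
rewards = torch.tensor([1.0, 0.0,   # prompt 0: two generations
                        0.5, 0.5,   # prompt 1
                        1.0, 1.0])  # prompt 2
mean_grouped_rewards = rewards.view(-1, num_gens).mean(dim=1)
print(mean_grouped_rewards)  # tensor([0.5000, 0.5000, 1.0000])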

Summary of Modified Files

  1. trl/trainer/grpo_config.py: Added num_generations_eval parameter definition
  2. trl/trainer/grpo_trainer.py: Modified 5 locations:
    • Line 383: Store the parameter in __init__
    • Lines 760-768: Updated _get_eval_sampler method
    • Lines 1166-1173: Updated vLLM server mode generation
    • Lines 1223-1231: Updated prompt repetition logic
    • Lines 1616-1621: Updated reward computation

Backward Compatibility

Fully backward compatible: When num_generations_eval is None (default), the trainer falls back to using num_generations, ensuring existing configurations work without any changes.

Example Usage

args = GRPOConfig(
    num_generations=8,        # Use 8 generations during training
    num_generations_eval=2,   # Use only 2 generations during evaluation (4x faster)
    ...
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Benefits

  • Faster evaluation: Reduce evaluation time by using fewer generations
  • Cost savings: Lower computational costs during evaluation
  • Maintained quality: Training quality remains unchanged
  • Flexible: Users can choose the optimal trade-off between speed and evaluation accuracy

Who can review?

This PR is ready for review! Any community member is welcome to provide feedback.
A special thanks to @qgallouedec for considering this PR.
As my first open-source contribution, I'm excited to learn, so please don't hesitate to suggest any enhancements!

@qgallouedec (Member) left a comment

Thanks for your contribution, and welcome to the open-source community!
Regarding the PR — I don’t see a scenario where having a different number of generations during evaluation and training would be necessary. I’ll leave the PR open for now to see if the community expresses interest in this feature. If not, we can close it later.

@mingxuetian (Author) commented Nov 6, 2025

Thanks for your contribution, and welcome to the open-source community! Regarding the PR — I don’t see a scenario where having a different number of generations during evaluation and training would be necessary. I’ll leave the PR open for now to see if the community expresses interest in this feature. If not, we can close it later.

Thank you @qgallouedec for the review!
I'd like to explain the practical motivation, based on my own training experience and on issues #3539 (@SnorkelerVigi) and #3566 (@CasanovaLLL):

Community Need

Both issues #3539 and #3566 specifically request this feature because evaluation overhead was a major bottleneck in their training pipelines. @qgallouedec, you also mentioned in your replies to those issues that this problem needs to be addressed, which underscores the need for this feature.

Why Different num_generations?

During Training:

  • Large num_generations (e.g., 16) is essential for accurate advantage estimation via group-wise reward normalization: mean_grouped_rewards = rewards.view(-1, num_gens).mean(dim=1)
  • More samples per prompt → more stable advantages → better training quality (see the sketch after these lists)

During Evaluation:

  • We only need to monitor model performance, not compute gradients
  • Evaluation metrics are reliable with far fewer generations (e.g., 2)
  • Reusing the training-time num_generations during evaluation significantly slows down the overall training run
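
As a rough sketch of why this asymmetry makes sense (made-up uniform rewards, plain PyTorch): the group-mean baseline used for advantages is noticeably noisier when estimated from 2 generations than from 16, while a simple eval-time reward average tolerates that extra noise.

import torch

torch.manual_seed(0)
n_trials = 10_000
rewards = torch.rand(n_trials, 16)           # simulated per-generation rewards for one prompt
std_g16 = rewards.mean(dim=1).std()          # spread of the group mean with 16 generations
std_g2 = rewards[:, :2].mean(dim=1).std()    # spread of the group mean with only 2
print(f"group=16: {std_g16.item():.4f}  group=2: {std_g2.item():.4f}")  # the 16-sample mean is ~2.8x tighter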

Real Impact

From my experiments:

  • Setup: num_generations=16 (train), num_generations_eval=2 (eval)
  • Result: ~87.5% faster evaluation (evaluation time reduced to 1/8 of the original)
  • Evaluation metrics remained statistically equivalent

I would greatly appreciate it if you could carefully consider this PR. Thank you!

(Screenshots of the referenced issues attached.)

@mingxuetian (Author) commented

@qgallouedec, gentle ping. This PR directly addresses the problem you acknowledged in issue #3539, which is also a prerequisite for #3566. It provides a solution to the problem you confirmed needs fixing. A quick decision on this would be appreciated.

@qgallouedec (Member) commented

Thanks for the PR, we have limited bandwidth, please be patient

@mingxuetian (Author) commented Nov 11, 2025

Thanks for the PR, we have limited bandwidth, please be patient

Thanks for the update. No problem, I completely understand the bandwidth constraints, and I appreciate your and the team's hard work. I'll stay patient; please don't hesitate to reach out if you have any questions. Looking forward to your review when time permits.
