
[RLlib][PPO new-API] Large discrepancy between Algorithm.evaluate() and manual inference via restored EnvToModule/ModuleToEnv pipelines on CarRacing-v3 #53588

@lukaskiss222

What happened + What you expected to happen

Following the policy_inference example implementation and comparing its results to those of Algorithm.evaluate() does not yield the same performance.

I re-ran the evaluation and the manual inference loop 1000 times each on the same finished checkpoint. The image below shows the results:
[Image: distribution of episode returns from Algorithm.evaluate() vs. the manual inference loop over the 1000 runs]
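
For reference, the two distributions in the image can be recomputed from the arrays saved by the reproduction script below (eval_results.npy and our_eval.npy). A minimal comparison sketch, assuming both files are in the working directory and matplotlib is installed:

import numpy as np
import matplotlib.pyplot as plt

# Mean episode return reported by each Algorithm.evaluate() call.
eval_results = np.load("eval_results.npy")
# Per-episode return from the manual EnvToModule -> RLModule -> ModuleToEnv loop.
our_eval = np.load("our_eval.npy")

print(f"Algorithm.evaluate(): mean={eval_results.mean():.1f}, std={eval_results.std():.1f}")
print(f"manual inference:     mean={our_eval.mean():.1f}, std={our_eval.std():.1f}")

plt.hist(eval_results, bins=50, alpha=0.5, label="Algorithm.evaluate()")
plt.hist(our_eval, bins=50, alpha=0.5, label="manual inference loop")
plt.xlabel("episode return")
plt.ylabel("count")
plt.legend()
plt.show()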

Versions / Dependencies

Python 3.13.3
ray 2.46.0
torch 2.7.0.dev20250302+cu128
Linux fedora-desktop (OS: Fedora release 42 (Adams) x86_64) 6.14.9-300.fc42.x86_64 #1 SMP PREEMPT_DYNAMIC Thu May 29 14:27:53 UTC 2025 x86_64 GNU/Linux
gymnasium 1.0.0
CPU: AMD Ryzen 9 7950X3D (32) @ 5.763GHz
GPU: NVIDIA GeForce RTX 4090

Reproduction script

The script is simplified and stops training once the agent has learned something, so the two evaluation methods can be compared.

import os

import ray
import numpy as np
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.connectors.env_to_module import EnvToModulePipeline
from ray.rllib.connectors.module_to_env import ModuleToEnvPipeline
from ray.rllib.core import (
    COMPONENT_ENV_RUNNER,
    COMPONENT_ENV_TO_MODULE_CONNECTOR,
    COMPONENT_MODULE_TO_ENV_CONNECTOR,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
    DEFAULT_MODULE_ID,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
torch, _ = try_import_torch()
import gymnasium as gym
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.utils.numpy import convert_to_numpy
# Fix for saving checkpoints, https://github.com/ray-project/ray/issues/53467
class FixedPPOLearner(PPOTorchLearner):
    def get_state(
        self,
        *args,
        **kwargs,
    ):
        state = super().get_state(*args, **kwargs)
        if "metrics_logger" in state.keys():
            state["metrics_logger"] = convert_to_numpy(state["metrics_logger"])
        return state




class CarRacingFloat(gym.ObservationWrapper):
    """Observation wrapper for the CarRacing-v3 environment to convert
    the observation space to float32.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(0, 255, (96, 96, 3), np.float32)

    def observation(self, observation):
        return observation.astype(np.float32)


tune.register_env(
    "FloatCarRacing-v3",
    lambda config: CarRacingFloat(
        gym.make(
            "CarRacing-v3",
        )
    ),
)

n_envs = 8            # number of parallel environments per env runner
n_steps = 1024        # rollout fragment length per environment
total_iterations = 200
ray.init()
if __name__ == "__main__":

    abs_path = os.path.abspath("ray_storage")
    explore_during_inference = False

    base_config = (
        PPOConfig()
        .environment(
            "FloatCarRacing-v3",
        ).env_runners(
            num_env_runners=1,
            num_envs_per_env_runner=n_envs,  # Number of parallel environments
            num_cpus_per_env_runner=n_envs,
            gym_env_vectorize_mode="ASYNC",
            num_gpus_per_env_runner=0.3,
            rollout_fragment_length=n_steps,  # Number of steps per environment per rollout
        )
        .training(
            learner_class=FixedPPOLearner,
            train_batch_size_per_learner=n_envs * n_steps,  # Total batch size per learner
            num_epochs=15,
            lr=0.0005,
            minibatch_size=128,
            gamma=0.995,
            lambda_=0.95,
            vf_loss_coeff=1.0,
            vf_clip_param=np.inf,  # No clipping for value function
            entropy_coeff=0.000,
            grad_clip=0.5,  # Gradient clipping
            grad_clip_by="global_norm",
        ).learners(
            num_learners=1,
            num_gpus_per_learner=0.5,
            num_cpus_per_learner=0,
        )
        .evaluation(
            evaluation_interval=total_iterations//5 + total_iterations % 5,
            evaluation_duration=n_envs,
            evaluation_duration_unit="episodes",
        )
    )
    tuner = tune.Tuner(
        "PPO",
        param_space=base_config.to_dict(),
        run_config=tune.RunConfig(
            verbose=1,
            name="ppo_carracing",
            stop={"training_iteration": total_iterations},
            storage_path=f"file://{abs_path}",
            checkpoint_config=tune.CheckpointConfig(
                num_to_keep=3,  # Keep the last 3 checkpoints
                checkpoint_frequency=10,  # Save a checkpoint every 10 iterations
                checkpoint_at_end=True,  # Always save a final checkpoint at the end of training
            ),
        ),
    )


    results = tuner.fit()

    # Get the last checkpoint from the above training run.
    metric_key = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
    best_result = results.get_best_result(metric=metric_key, mode="max")

    algo = Algorithm.from_checkpoint(
        best_result.checkpoint.path
    )
    sample_times = 1000
    eval_results = np.zeros(sample_times)
    our_eval = np.zeros(sample_times)
    for i in range(sample_times):
        return_mean = algo.evaluate()['env_runners']['episode_return_mean']
        eval_results[i] = return_mean

    print(
        "Training completed (R="
        f"{best_result.metrics[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}). "
        "Creating an env-loop for inference ..."
    )

    env = algo.env_creator({})

    # Create the env-to-module pipeline from the checkpoint.
    print("Restore env-to-module connector from checkpoint ...", end="")
    env_to_module = EnvToModulePipeline.from_checkpoint(
        os.path.join(
            best_result.checkpoint.path,
            COMPONENT_ENV_RUNNER,
            COMPONENT_ENV_TO_MODULE_CONNECTOR,
        )
    )
    print(" ok")

    print("Restore RLModule from checkpoint ...", end="")
    # Create RLModule from a checkpoint.
    rl_module = RLModule.from_checkpoint(
        os.path.join(
            best_result.checkpoint.path,
            COMPONENT_LEARNER_GROUP,
            COMPONENT_LEARNER,
            COMPONENT_RL_MODULE,
            DEFAULT_MODULE_ID,
        )
    )
    print(" ok")

    # For the module-to-env pipeline, we will use the convenient config utility.
    print("Restore module-to-env connector from checkpoint ...", end="")
    module_to_env = ModuleToEnvPipeline.from_checkpoint(
        os.path.join(
            best_result.checkpoint.path,
            COMPONENT_ENV_RUNNER,
            COMPONENT_MODULE_TO_ENV_CONNECTOR,
        )
    )
    print(" ok")

    # Now our setup is complete:
    # [gym.Env] -> env-to-module -> [RLModule] -> module-to-env -> [gym.Env] ... repeat
    all_total_rewards = []
    for num_episodes in range(sample_times):
        obs, _ = env.reset()
        episode = SingleAgentEpisode(
            observations=[obs],
            observation_space=env.observation_space,
            action_space=env.action_space,
        )
        while not episode.is_done:
            shared_data = {}
            input_dict = env_to_module(
                episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
                rl_module=rl_module,
                explore=explore_during_inference,
                shared_data=shared_data,
            )
            # No exploration (using RLModule).
            if not explore_during_inference:
                rl_module_out = rl_module.forward_inference(input_dict)
            # W/ exploration (using RLModule).
            else:
                rl_module_out = rl_module.forward_exploration(input_dict)

            to_env = module_to_env(
                batch=rl_module_out,
                episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
                rl_module=rl_module,
                explore=explore_during_inference,
                shared_data=shared_data,
            )
            # Send the computed action to the env. Note that the RLModule and the
            # connector pipelines work on batched data (B=1 in this case), whereas the Env
            # is not vectorized here, so we need to use `action[0]`.
            action = to_env.pop(Columns.ACTIONS)[0]
            obs, reward, terminated, truncated, _ = env.step(action)
            # Keep our `SingleAgentEpisode` instance updated at all times.
            episode.add_env_step(
                obs,
                action,
                reward,
                terminated=terminated,
                truncated=truncated,
                # Same here: [0] b/c RLModule output is batched (w/ B=1).
                extra_model_outputs={k: v[0] for k, v in to_env.items()},
            )

            # Episode finished -> record its return.
            if episode.is_done:
                our_eval[num_episodes] = episode.get_return()
                all_total_rewards.append(episode.get_return())
                print(
                    f"Episode {num_episodes} done. "
                    f"Total reward: {episode.get_return()}. "
                    f"Mean total reward so far: {np.mean(all_total_rewards):.2f}"
                )

    print(f"Done performing action inference through {num_episodes} Episodes")
    # save the results as npy
    np.save("eval_results.npy", eval_results)
    np.save("our_eval.npy", our_eval)
    env.close()
    ray.shutdown()

Issue Severity

High: It blocks me from completing my task.

Labels

P3 (moderate in impact or severity), bug (something that is supposed to be working but isn't), rllib (RLlib related issues), stability, usability
