Skip to content

Conversation

@KleistvonLiu
Copy link
Contributor

What this does

The original implementation of Beta distribution for pi0 seems wrong. This PR compares sample results from three implementation (original openpi, previous lerobot and fixed), and shows that the fixed implementation is aligned to original openpi implementation and also the plot from Pi0 paper.

As shown in Pi0 paper, we sample time from Beta distribution and hope to emphasize timesteps close to noise action (in the Pi0's paper time close to 0, in Pi0's code time close to 1).
image

The following code compares the sample result from three implementation:

import jax
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def sample_beta(alpha, beta, bsize):
    gamma1 = torch.empty((bsize,)).uniform_(0, 1).pow(1 / alpha)
    gamma2 = torch.empty((bsize,)).uniform_(0, 1).pow(1 / beta)
    return gamma1 / (gamma1 + gamma2)

def sample_time(bsize):
    time_beta = sample_beta(1.5, 1.0, bsize)
    time = time_beta * 0.999 + 0.001
    return time.to(dtype=torch.float32)

batch_size = 10_000

# ① previous implementation from lerobot
time_samples1 = sample_time(batch_size).cpu().numpy()

# ② original openpi from pi0 (JAX)
seed = 42
rng  = jax.random.PRNGKey(seed)
rng, time_rng = jax.random.split(rng)
time2 = jax.random.beta(time_rng, 1.5, 1.0, (batch_size,)) * 0.999 + 0.001
time_samples2 = np.array(time2)  # JAX → NumPy

# ③ fixed implementation from lerobot (PyTorch Beta)
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_samples3 = (beta_dist.sample((batch_size,)) * 0.999 + 0.001).cpu().numpy()

# ---- Plot three subplots ----
fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True, sharey=True)

bins = 200
kwargs = dict(bins=bins, density=True, edgecolor="black", alpha=0.7)

axes[0].hist(time_samples1, **kwargs)
axes[0].set_title("Lerobot previous: Beta(1.5,1.0) scaled to (0.001,1.0)")

axes[1].hist(time_samples2, **kwargs)
axes[1].set_title("OpenPI (JAX): Beta(1.5,1.0) scaled to (0.001,1.0)")

axes[2].hist(time_samples3, **kwargs)
axes[2].set_title("Fixed PyTorch: Beta(1.5,1.0) scaled to (0.001,1.0)")

for ax in axes:
    ax.grid(True, alpha=0.3)
axes[-1].set_xlabel("t")
for ax in axes:
    ax.set_ylabel("Density")

fig.tight_layout()
fig.savefig("beta_hist_3subplots.png", dpi=200)

By running above code we can get following figure, which shows that the previous lerobot implementation is not as expected while the fixed one works well.
beta_hist_3subplots

Copy link
Collaborator

@fracapuano fracapuano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @KleistvonLiu thank you so much for the PR, and what a nice catch ⭐

Happy to get this merged if you can:

  • (1) maintain sampling encapsulated (sample_beta is better than self.beta_distribution in terms of statefulness of the policy object)

Again, thank you so very much! 🙏

@pkooij pkooij added bug Something isn’t working correctly policies Items related to robot policies labels Jul 28, 2025
@KleistvonLiu
Copy link
Contributor Author

Hey @KleistvonLiu thank you so much for the PR, and what a nice catch ⭐

Happy to get this merged if you can:

  • (1) maintain sampling encapsulated (sample_beta is better than self.beta_distribution in terms of statefulness of the policy object)

Again, thank you so very much! 🙏

Hi @fracapuano thank you for you review and reply ! I have modified the code and I hope it work well, please review it again. 🙏

Copy link
Collaborator

@fracapuano fracapuano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you, and what a nice catch @KleistvonLiu

@fracapuano fracapuano merged commit 4b88842 into huggingface:main Jul 28, 2025
8 checks passed
milong26 added a commit to milong26/lerobot_diy that referenced this pull request Jul 29, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
AdilZouitine pushed a commit that referenced this pull request Aug 10, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
milong26 pushed a commit to milong26/lerobot_diy that referenced this pull request Aug 26, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
milong26 added a commit to milong26/lerobot_diy that referenced this pull request Aug 26, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
Ricci084 pushed a commit to JeffWang987/lerobot that referenced this pull request Sep 5, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
BillmanH pushed a commit to BillmanH/lerobot that referenced this pull request Sep 7, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
fracapuano pushed a commit that referenced this pull request Sep 12, 2025
* fix bug about sampling t from beta distribution

* fix: address review comments

---------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn’t working correctly policies Items related to robot policies

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants