
Conversation

@YushunXiang
Contributor

@YushunXiang YushunXiang commented Jun 30, 2025

The issue for this bug is #1406, which is probably caused by huggingface/transformers#37033. In the v4.52.1 release of the transformers library, that PR introduced a breaking change by renaming class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin) to class PaliGemmaModel(PaliGemmaPreTrainedModel).

This pull request introduces enhancements to the PI0Policy class in lerobot/common/policies/pi0/modeling_pi0.py to improve model state handling. The changes add a method to transform state dictionary keys and a class method to load model weights from safetensor files, ensuring compatibility with the expected model structure. Fixes #1406.

Enhancements to model state handling:

  • Key transformation for state dictionaries: Added _transform_state_dict_keys method to modify state dictionary keys for compatibility with expected model structure. This includes specific transformations for PaliGemma components to ensure proper mapping of model layers.

  • Support for safetensor file loading: Introduced _load_as_safetensor class method to load model weights from safetensor files. This method applies the key transformations before loading the state dictionary into the model.

  • Key transformations applied to PaliGemma components (a minimal sketch of both methods follows this list)

    • model.paligemma_with_expert.paligemma.language_model.lm_head -> model.paligemma_with_expert.paligemma.lm_head
    • model.paligemma_with_expert.paligemma.language_model.model -> model.paligemma_with_expert.paligemma.model.language_model
    • model.paligemma_with_expert.paligemma.vision_tower -> model.paligemma_with_expert.paligemma.model.vision_tower
    • model.paligemma_with_expert.paligemma.multi_modal_projector -> model.paligemma_with_expert.paligemma.model.multi_modal_projector
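
For illustration, here is a minimal sketch of what the two methods could look like, assuming the prefix mappings listed above; the class body, method signatures, and helper names are simplified and may differ from the actual implementation in the PR.

from safetensors.torch import load_file


class PI0Policy:  # simplified; the real class is in modeling_pi0.py and subclasses lerobot's policy base
    # Old checkpoint prefix -> new prefix expected by the refactored PaliGemma classes (mappings from the list above).
    _KEY_MAPPING = {
        "model.paligemma_with_expert.paligemma.language_model.lm_head": "model.paligemma_with_expert.paligemma.lm_head",
        "model.paligemma_with_expert.paligemma.language_model.model": "model.paligemma_with_expert.paligemma.model.language_model",
        "model.paligemma_with_expert.paligemma.vision_tower": "model.paligemma_with_expert.paligemma.model.vision_tower",
        "model.paligemma_with_expert.paligemma.multi_modal_projector": "model.paligemma_with_expert.paligemma.model.multi_modal_projector",
    }

    @classmethod
    def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
        """Rename old-style checkpoint keys to the layout the new PaliGemma classes expect."""
        renamed = {}
        for key, value in state_dict.items():
            for old_prefix, new_prefix in cls._KEY_MAPPING.items():
                if key.startswith(old_prefix):
                    key = new_prefix + key[len(old_prefix):]
                    break
            renamed[key] = value
        return renamed

    @classmethod
    def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
        """Load a .safetensors file, remap its keys, then load it into the policy module."""
        state_dict = load_file(model_file, device=map_location)
        state_dict = cls._transform_state_dict_keys(state_dict)
        model.load_state_dict(state_dict, strict=strict)
        return model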

Environment

  • transformers: 4.53.0


@Cadene, @mshukor

Copilot AI review requested due to automatic review settings June 30, 2025 22:06


@OpenJarvisAI

Hi, can you consider making it more robust? I suppose transformers might rename them back, and then your patch would fail again.

@YushunXiang
Contributor Author

Hi, can you consider making it more robust? I suppose transformers might rename them back, and then your patch would fail again.

I will make it robust in the future.

@OpenJarvisAI

@YushunXiang Hi, actually your code still misses some params:

Missing key(s) in state_dict: "normalize_inputs.buffer_observation_state.mean", "normalize_inputs.buffer_observation_state.std", "normalize_targets.buffer_action.mean", "normalize_targets.buffer_action.std", "unnormalize_outputs.buffer_action.mean", "unnormalize_outputs.buffer_action.std", "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight".

The last one, did you notice it?

@YushunXiang
Contributor Author

@YushunXiang Hi, actually your code still misses some params:

Missing key(s) in state_dict: "normalize_inputs.buffer_observation_state.mean", "normalize_inputs.buffer_observation_state.std", "normalize_targets.buffer_action.mean", "normalize_targets.buffer_action.std", "unnormalize_outputs.buffer_action.mean", "unnormalize_outputs.buffer_action.std", "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight".

The last one, did you notice it?

I have noticed that, but I don't think it's a mapping problem: if there were missing keys like that, shouldn't there also be corresponding unexpected keys? That's a good question; I will look into it later. Can you help me figure out the problem together?

@OpenJarvisAI

Sure, since I still can't make it work and you did, I think I need to align with you first.

My question is: training with 8 GPUs doesn't work. The policy loss goes down to about 0.06 and then stops decreasing.

@YushunXiang
Contributor Author

YushunXiang commented Jul 7, 2025

Sure, since I still can't make it work and you did, I think I need to align with you first.

My question is: training with 8 GPUs doesn't work. The policy loss goes down to about 0.06 and then stops decreasing.

Here are my training loss curves (batch size = 16).

w/ this PR

  • Loss: [image]
  • Learning Rate: [image]
  • L2 Loss: [image]

The lowest loss value is about 0.002.

w/o this PR

  • Loss: [image]

The lowest loss value is about 0.012.

@OpenJarvisAI

@YushunXiang With a single GPU, is this normal?

2025-07-07 19:03:23.939 | INFO     | __main__:train:230 - Checkpoint policy after step 18200
2025-07-07 19:03:52.801 | INFO     | __main__:train:221 - step:18K smpl:255K ep:911 epch:2.28 loss:0.004 grdn:0.206 lr:1.0e-05 updt_s:0.524 data_s:0.000
2025-07-07 19:04:03.350 | INFO     | __main__:train:221 - step:18K smpl:255K ep:912 epch:2.28 loss:0.004 grdn:0.220 lr:1.0e-05 updt_s:0.525 data_s:0.000
2025-07-07 19:04:13.922 | INFO     | __main__:train:221 - step:18K smpl:256K ep:913 epch:2.28 loss:0.004 grdn:0.200 lr:1.0e-05 updt_s:0.526 data_s:0.000
2025-07-07 19:04:24.491 | INFO     | __main__:train:221 - step:18K smpl:256K ep:914 epch:2.28 loss:0.004 grdn:0.220 lr:1.0e-05 updt_s:0.526 data_s:0.000
2025-07-07 19:04:35.060 | INFO     | __main__:train:221 - step:18K smpl:256K ep:915 epch:2.29 loss:0.003 grdn:0.199 lr:9.9e-06 updt_s:0.526 data_s:0.000

It looks like with multiple GPUs the loss does not decrease below 0.006.

@OpenJarvisAI

@YushunXiang Do you know how to set the learning rate for multiple GPUs?

accelerate launch train_multi.py \
  --policy.path=$MODEL_PATH \
  --policy.optimizer_lr=$LR \
  --dataset.repo_id=$DATA_PATH \
  --dataset.image_transforms.enable=true \
  --dataset.image_transforms.random_order=true \
  --output_dir=outputs/$POLICY_TYPE/$EXP_NAME-$DATASET_NAME-$DATE \
  --batch_size=$GLOBAL_BATCH_SIZE \
  --steps=30000 \
  --log_freq=20 \
  --save_freq=100

I am confused why the args throw an error:

train_multi.py: error: unrecognized arguments: --optimizer_lr=0.00016
usage: train_multi.py [-h] [--config_path str] [--dataset str] [--dataset.repo_id str] [--dataset.root str] [--dataset.episodes str] [--image_transforms str]
                      [--dataset.image_transforms.enable str] [--dataset.image_transforms.max_num_transforms str] [--dataset.image_transforms.random_order str]
                      [--dataset.image_transforms.tfs str] [--dataset.revision str] [--dataset.use_imagenet_stats str] [--dataset.video_backend str] [--env str]
                      [--env.type {aloha,pusht,xarm}] [--env.task str] [--env.fps str] [--env.features str] [--env.features_map str] [--env.episode_length str]
                      [--env.obs_type str] [--env.render_mode str] [--env.visualization_width str] [--env.visualization_height str] [--policy str]
                      [--policy.type {pi0,smolvla,pi0fast}] [--policy.attention_implementation str] [--policy.num_steps str] [--policy.train_expert_only str]
                      [--policy.train_state_proj str] [--policy.optimizer_grad_clip_norm str] [--policy.vlm_model_name str] [--policy.load_vlm_weights str]
                      [--policy.add_image_special_tokens str] [--policy.attention_mode str] [--policy.prefix_length str] [--policy.pad_language_to str]
                      [--policy.num_expert_layers str] [--policy.num_vlm_layers str] [--policy.self_attn_every_n_layers str] [--policy.expert_width_multiplier str]
                      [--policy.min_period str] [--policy.max_period str] [--policy.n_obs_steps str] [--policy.normalization_mapping str] [--policy.input_features str]
                      [--policy.output_features str] [--policy.device str] [--policy.use_amp str] [--policy.gradient_accumulation_steps str] [--policy.chunk_size str]
                      [--policy.n_action_steps str] [--policy.max_state_dim str] [--policy.max_action_dim str] [--policy.resize_imgs_with_padding str]
                      [--policy.interpolate_like_pi str] [--policy.empty_cameras str] [--policy.adapt_to_pi_aloha str] [--policy.use_delta_joint_actions_aloha str]
                      [--policy.tokenizer_max_length str] [--policy.proj_width str] [--policy.max_decoding_steps str] [--policy.fast_skip_tokens str]
                      [--policy.max_input_seq_len str] [--policy.use_cache str] [--policy.freeze_vision_encoder str] [--policy.freeze_lm_head str] [--policy.optimizer_lr str]
                      [--policy.optimizer_betas str] [--policy.optimizer_eps str] [--policy.optimizer_weight_decay str] [--policy.scheduler_warmup_steps str]
                      [--policy.scheduler_decay_steps str] [--policy.scheduler_decay_lr str] [--policy.checkpoint_path str] [--policy.padding_side str] [--policy.precision str]
                      [--policy.grad_clip_norm str] [--policy.relaxed_action_decoding str] [--output_dir str] [--job_name str] [--resume str] [--seed str] [--num_workers str]
                      [--batch_size str] [--steps str] [--eval_freq str] [--log_freq str] [--save_checkpoint str] [--save_freq str] [--use_policy_training_preset str]
                      [--optimizer str] [--optimizer.type {adam,adamw,sgd}] [--optimizer.betas str] [--optimizer.eps str] [--optimizer.lr str] [--optimizer.weight_decay str]
                      [--optimizer.grad_clip_norm str] [--optimizer.momentum str] [--optimizer.dampening str] [--optimizer.nesterov str] [--scheduler str]
                      [--scheduler.type {diffuser,vqbet,cosine_decay_with_warmup}] [--scheduler.name str] [--scheduler.num_vqvae_training_steps str] [--scheduler.num_cycles str]
                      [--scheduler.num_warmup_steps str] [--scheduler.num_decay_steps str] [--scheduler.peak_lr str] [--scheduler.decay_lr str] [--eval str]
                      [--eval.n_episodes str] [--eval.batch_size str] [--eval.use_async_envs str] [--wandb str] [--wandb.enable str] [--wandb.disable_artifact str]
                      [--wandb.project str] [--wandb.entity str] [--wandb.notes str] [--wandb.run_id str] [--wandb.mode str]
train_multi.py: error: unrecognized arguments: --optimizer_lr=0.00016

@YushunXiang
Contributor Author

@YushunXiang Do you know how to set the learning rate for multiple GPUs?

accelerate launch train_multi.py \
  --policy.path=$MODEL_PATH \
  --policy.optimizer_lr=$LR \
  --dataset.repo_id=$DATA_PATH \
  --dataset.image_transforms.enable=true \
  --dataset.image_transforms.random_order=true \
  --output_dir=outputs/$POLICY_TYPE/$EXP_NAME-$DATASET_NAME-$DATE \
  --batch_size=$GLOBAL_BATCH_SIZE \
  --steps=30000 \
  --log_freq=20 \
  --save_freq=100

I am confused why the args throw an error:

train_multi.py: error: unrecognized arguments: --optimizer_lr=0.00016

You should use --policy.optimizer_lr instead of --optimizer_lr.

@OpenJarvisAI

Hi, I tried --policy.optimizer_lr, but somehow draccus didn't parse it correctly.

I'm so confused.

Also, I found that loading the trained model back still throws an error:

Missing key(s) in state_dict: "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight".

Have you tried strict=True when loading the trained model? It still throws this error.

@YushunXiang
Contributor Author

YushunXiang commented Jul 8, 2025

Hi, I tried --policy.optimizer_lr, but somehow draccus didn't parse it correctly.

I'm so confused.

class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """
    Base configuration class for policy models.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        input_shapes: A dictionary defining the shapes of the input data for the policy.
        output_shapes: A dictionary defining the shapes of the output data for the policy.
        input_normalization_modes: A dictionary with key representing the modality and the value specifies the
            normalization mode to apply.
        output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
            the original scale.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)
    device: str | None = None  # cuda | cpu | mp
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False
    push_to_hub: bool = True
    repo_id: str | None = None
    # Upload on private repository on the Hugging Face hub.
    private: bool | None = None
    # Add tags to your policy on the hub.
    tags: list[str] | None = None
    # Add tags to your policy on the hub.
    license: str | None = None

The PreTrainedConfig class shown above does not contain optimizer_lr, which causes this error.

Without modifying the source code, I think it's a good idea to change the value of optimizer_lr in lerobot/pi0/config.json instead.
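
For reference, a minimal sketch of that workaround; the config path is the one mentioned above, and the actual location of the downloaded checkpoint's config.json may differ on your machine.

import json
from pathlib import Path

# Path taken from the comment above; point this at your checkpoint's config.json.
config_path = Path("lerobot/pi0/config.json")
config = json.loads(config_path.read_text())
config["optimizer_lr"] = 1.6e-4  # the value you tried to pass as --optimizer_lr=0.00016
config_path.write_text(json.dumps(config, indent=2))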

Also, I found that loading the trained model back still throws an error:

Missing key(s) in state_dict: "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight".

Have you tried strict=True when loading the trained model? It still throws this error.

I have tried it, and the error message is the same as yours.

@YushunXiang
Contributor Author

YushunXiang commented Jul 8, 2025

@YushunXiang Hi, actually your code still misses some params:

Missing key(s) in state_dict: "normalize_inputs.buffer_observation_state.mean", "normalize_inputs.buffer_observation_state.std", "normalize_targets.buffer_action.mean", "normalize_targets.buffer_action.std", "unnormalize_outputs.buffer_action.mean", "unnormalize_outputs.buffer_action.std", "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight".

The last one, did you notice it?

Don't worry about that. The model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight is the same as model.paligemma_with_expert.paligemma.lm_head.weight. You can check torch.Tensor.untyped_storage().data_ptr() and torch.Tensor.untyped_storage().nbytes(): the memory address and memory byte size are identical, which proves that the two tensors share the same underlying memory.
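
For example, a quick check along those lines; this is a sketch assuming policy is an already-loaded PI0Policy and that the attribute path follows the key names above.

# `policy` is assumed to be a loaded PI0Policy, e.g. obtained via PI0Policy.from_pretrained(...).
paligemma = policy.model.paligemma_with_expert.paligemma
lm_head_w = paligemma.lm_head.weight
embed_w = paligemma.model.language_model.embed_tokens.weight

# Tied weights share one storage: same data pointer and same byte size.
print(lm_head_w.untyped_storage().data_ptr() == embed_w.untyped_storage().data_ptr())  # True
print(lm_head_w.untyped_storage().nbytes() == embed_w.untyped_storage().nbytes())      # True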

When I was reading the source code of transformers, I found that one of the components of pi0, GemmaForCausalLM (which is model.paligemma_with_expert.paligemma.model.language_model), has tie_word_embeddings=True set in its config file. When a model defines both get_input_embeddings() and get_output_embeddings(), the transformers framework automatically ties them together in the tie_weights() method.

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
        weights instead.
        """
        if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True):
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
            if hasattr(self, self.base_model_prefix):
                self = getattr(self, self.base_model_prefix)
            tied_weights = self._tie_encoder_decoder_weights(
                self.encoder, self.decoder, self.base_model_prefix, "encoder"
            )
            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
            # attributed not an instance member, therefore modifying it will modify the entire class
            # Leading to issues on subsequent calls by different tests or subsequent calls.
            self._dynamic_tied_weights_keys = tied_weights

        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

This means that lm_head.weight and embed_tokens.weight are equivalent.
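
If the missing-key warning is a nuisance, one possible workaround, sketched under the assumption that the tying described above holds and that state_dict and model come from your loading code, is to materialize the tied entry before a strict load:

# Because the two parameters are tied, a checkpoint usually stores only one of them;
# copy whichever is present onto the other key before load_state_dict(strict=True).
LM_HEAD_KEY = "model.paligemma_with_expert.paligemma.lm_head.weight"
EMBED_KEY = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"

if EMBED_KEY not in state_dict and LM_HEAD_KEY in state_dict:
    state_dict[EMBED_KEY] = state_dict[LM_HEAD_KEY]
elif LM_HEAD_KEY not in state_dict and EMBED_KEY in state_dict:
    state_dict[LM_HEAD_KEY] = state_dict[EMBED_KEY]

model.load_state_dict(state_dict, strict=True)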

@OpenJarvisAI

@YushunXiang
Contributor Author

I have a question: would it be more elegant to convert the model's own state_dict() key names at load time to match the checkpoint, instead of converting the state_dict loaded from the file?

@YushunXiang YushunXiang requested a review from Copilot July 8, 2025 20:09
Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes checkpoint state mismatches for the PI0Policy by transforming state dict keys and adds support for loading weights from safetensor files.

  • Adds a key transformation method to align PaliGemma layer names.
  • Introduces a safetensor loader that applies these transformations before model loading.
Comments suppressed due to low confidence (2)

src/lerobot/policies/pi0/modeling_pi0.py:262

  • [nitpick] Add more specific type annotations (e.g., Dict[str, torch.Tensor]) for the input and return values to improve code clarity and editor support.
    def _transform_state_dict_keys(cls, state_dict: dict) -> dict:

src/lerobot/policies/pi0/modeling_pi0.py:261

  • There’s no test coverage for the key-transformation logic; consider adding unit tests that verify each mapping and the tied-weights handling.
    @classmethod

@pkooij pkooij requested a review from michel-aractingi July 20, 2025 08:42
@pkooij pkooij added the policies Items related to robot policies label Jul 20, 2025
Collaborator

@michel-aractingi michel-aractingi left a comment


Thanks for this PR! I left a couple of comments.

@michel-aractingi michel-aractingi merged commit 71eff18 into huggingface:main Jul 30, 2025
12 checks passed
@michel-aractingi
Collaborator

Thank you for this fix! It is merged now.

@branyang02

I mean this is kinda ugly not gonna lie...

Is it possible to either:

  1. cap the transformers library version.
  2. copy all PaliGemma and related code over so we have a fixed implementation?

I suppose copying the PaliGemma code over is somewhat complicated, but I do think we would benefit from not having to worry about future transformers library updates, and it would also let us look into ways to speed things up and experiment with floating-point precision the way openpi does.

@michel-aractingi
Collaborator

michel-aractingi commented Jul 31, 2025

@branyang02 This is a temporary fix until we merge the pipeline PR #1431, which will do exactly what you suggested: bump transformers and change the model keys directly.

@YushunXiang
Contributor Author

@branyang02 Thank you for your advice. My code is indeed not elegant enough.

PR #1431 is wonderful work, and I have learned a lot from it.

@YushunXiang YushunXiang deleted the fix-pi0 branch July 31, 2025 11:46
Maelic pushed a commit to Maelic/lerobot that referenced this pull request Aug 4, 2025
@ymy1946676292

ymy1946676292 commented Aug 5, 2025

Sure, since I still can't make it work and you did, I think I need to align with you first.
My question is: training with 8 GPUs doesn't work. The policy loss goes down to about 0.06 and then stops decreasing.

Here are my training loss curves (batch size = 16).

w/ this PR

  • Loss: [image]
  • Learning Rate: [image]
  • L2 Loss: [image]

The lowest loss value is about 0.002.

w/o this PR

  • Loss: [image]

The lowest loss value is about 0.012.

Thanks for your work on LeRobot and for sharing the training configurations. However, when I try to reproduce training with the pi0 policy on the Libero dataset using the same config, I notice that the training loss remains consistently high in the early stage even on a single GPU setup.

Here are the details:

🔧 Training Configuration

'dataset': {
  'root': '/home/Program/lerobot_new/datasets/libero_10_no_noops_1.0.0_lerobot',
  'video_backend': 'torchcodec',
  'use_imagenet_stats': True,
  'image_transforms': {
    'enable': True,
    'max_num_transforms': 3,
    'random_order': True,
    'tfs': {
      'brightness': {'type': 'ColorJitter', 'weight': 1.0, 'kwargs': {'brightness': [0.8, 1.2]}},
      'contrast':   {'type': 'ColorJitter', 'weight': 1.0, 'kwargs': {'contrast': [0.8, 1.2]}},
      'hue':        {'type': 'ColorJitter', 'weight': 1.0, 'kwargs': {'hue': [-0.05, 0.05]}},
      'saturation': {'type': 'ColorJitter', 'weight': 1.0, 'kwargs': {'saturation': [0.5, 1.5]}},
      'sharpness':  {'type': 'SharpnessJitter', 'weight': 1.0, 'kwargs': {'sharpness': [0.5, 1.5]}}
    }
  }
},
'policy': {
  'type': 'pi0',
  'n_obs_steps': 1,
  'n_action_steps': 50,
  'chunk_size': 50,
  'proj_width': 1024,
  'tokenizer_max_length': 48,
  'freeze_vision_encoder': True,
  'train_state_proj': True,
  'resize_imgs_with_padding': [224, 224],
  'normalization_mapping': {
    'STATE': 'MEAN_STD',
    'ACTION': 'MEAN_STD',
    'VISUAL': 'IDENTITY'
  },
  'scheduler_decay_steps': 30000,
  'scheduler_warmup_steps': 1000,
  'scheduler_decay_lr': 2.5e-6,
  'optimizer_lr': 1e-4,
  'optimizer_weight_decay': 1e-10,
  'optimizer_betas': [0.9, 0.95]
},
'scheduler': {
  'type': 'cosine_decay_with_warmup',
  'peak_lr': 1e-4,
  'num_decay_steps': 30000,
  'num_warmup_steps': 1000,
  'decay_lr': 2.5e-6
},
'optimizer': {
  'type': 'adamw',
  'lr': 1e-4,
  'betas': [0.9, 0.95],
  'eps': 1e-8,
  'weight_decay': 1e-10
},
'use_policy_training_preset': True,
'device': 'cuda',
'use_amp': False,
'steps': 30000,
'log_freq': 10

Transformers version: 4.53.0

📉 Loss Output Sample

Here’s a snippet of the logs:

step:10  loss:0.160
step:20  loss:0.176
step:50  loss:0.176
step:100 loss:0.126
step:150 loss:0.127
step:200 loss:0.108
step:250 loss:0.098
step:300 loss:0.095
step:320 loss:0.101
step:1K loss:0.108
step:2K loss:0.090

Any insight or clarification would be appreciated! Thanks 🙏

@YushunXiang
Contributor Author

@ymy1946676292
Could you give me more log steps or loss curve graphs (both w/ PR and w/o PR) to determine whether it has converged?

@ymy1946676292

@ymy1946676292 Could you give me more log steps or loss curve graphs (both w/ PR and w/o PR) to determine whether it has converged?

Okay, below is the training loss curve over 160,000 steps. After multiple rounds of training, the loss stabilizes at about 0.05.
w/o PR
[image]
w/ PR
Currently, the training has reached around 20k steps, and the loss has decreased to about 0.05.
According to the training configuration provided in [Issue #952](#952), the model is expected to exhibit a relatively low loss at the initial stage. However, even with exactly the same configuration, the loss I obtained during training remained very high and did not decrease significantly even after prolonged training.

@YushunXiang
Contributor Author

@ymy1946676292 I guess that #952 may have introduced this problem?

@ymy1946676292

@ymy1946676292 I guess that #952 may have introduced this problem?

Thank you very much for your answer. I have tried to fix it using the scheme mentioned in #952, but after multiple rounds of training, the loss is still very high and the success rate is almost 0.

AdilZouitine pushed a commit that referenced this pull request Aug 10, 2025
milong26 pushed a commit to milong26/lerobot_diy that referenced this pull request Aug 26, 2025
Ricci084 pushed a commit to JeffWang987/lerobot that referenced this pull request Sep 5, 2025
BillmanH pushed a commit to BillmanH/lerobot that referenced this pull request Sep 7, 2025
fracapuano pushed a commit that referenced this pull request Sep 12, 2025