Marius Memmel*, Jacob Berg*, Bingqing Chen, Abhishek Gupta†, Jonathan Francis†
*equal contribution †equal advising
ICLR 2025
paper | arxiv | website | policy learning code
STRAP is a retrieval method that robustifies few-shot imitation learning by encoding trajectories with vision foundation models and retrieving sub-trajectories with subsequence dynamic time warping.
This repository contains our implementation of the trajectory encoding and the retrieval algorithm.
- Create the conda environment:
conda create -n strap python=3.10
conda activate strap
- Install the repository:
git clone https://github.com/WEIRDLabUW/STRAP.git
cd STRAP
pip install -e .
You're all set!
Code tested with Python 3.10, PyTorch 2.5.1, transformers 4.48.1, and CUDA 12.4.
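To compare your environment against the tested versions, a small check along the following lines may help. It is a sketch that uses only the Python standard library; the `TESTED` dictionary and the `installed_versions` helper are illustrative names, not part of this repository, and CUDA is not checked because it is not a Python package.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed as tested above (package names as published on PyPI).
TESTED = {"torch": "2.5.1", "transformers": "4.48.1"}


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, have in installed_versions(TESTED).items():
    # startswith() tolerates local build suffixes such as "+cu124".
    matches = have is not None and have.startswith(TESTED[name])
    status = "ok" if matches else "differs"
    print(f"{name}: installed={have}, tested={TESTED[name]} ({status})")
```

A mismatch does not necessarily mean the code will fail; it only tells you that you are outside the tested configuration.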
To replicate our results on the LIBERO datasets, first download the datasets using the download_libero.py script.
python data/download_libero.py
Next, encode the datasets using encode_datasets.py. By default, the script encodes the agentview_rgb observations in LIBERO-10 and LIBERO-90 using DINOv2. This can take several hours, depending on your hardware.
python strap/embedding/encode_datasets.py
Finally, run retrieval.py to retrieve sub-trajectories from the offline dataset. By default, the script selects 3 demos from the "put both moka pots on the stove" task in LIBERO-10 (prior dataset) and retrieves the top 100 sub-trajectories from LIBERO-90 (offline dataset). The sub-trajectories have a minimum length of 20 and are retrieved using the DINOv2 embeddings of agentview_rgb.
python strap/retrieval/retrieval.py
The retrieved dataset put_both_moka_pots_retrieved_dataset.hdf5 is saved in the data/retrieval_results folder. You can now use this dataset to train a policy using our policy learning code.
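For intuition about the retrieval step, the following is a minimal NumPy sketch of subsequence dynamic time warping. It is not the code in retrieval.py: the function names are illustrative, the squared Euclidean frame distance is an assumption, and keeping only one match per offline trajectory is a simplification. The query stands in for an embedded demo segment, and each trajectory stands in for an embedded offline trajectory.

```python
import numpy as np


def subsequence_dtw(query, series):
    """Find the sub-sequence of `series` that best aligns with `query`.

    Both inputs are (time, feature) arrays. Returns (cost, start, end),
    where `end` is inclusive.
    """
    # Pairwise squared Euclidean distance between frames (an assumption).
    diff = query[:, None, :] - series[None, :, :]
    cost = (diff ** 2).sum(axis=-1)
    n, m = cost.shape

    acc = np.full((n, m), np.inf)
    # Free start: the first query frame may align with any series frame.
    acc[0, :] = cost[0, :]
    for i in range(1, n):
        acc[i, 0] = acc[i - 1, 0] + cost[i, 0]
        for j in range(1, m):
            best_prev = min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])
            acc[i, j] = cost[i, j] + best_prev

    # Free end: the match may end at any series frame.
    end = int(np.argmin(acc[-1]))

    # Backtrack along the cheapest predecessors to recover the start index.
    i, j = n - 1, end
    while i > 0:
        if j == 0:
            i -= 1
            continue
        step = int(np.argmin([acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]]))
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    return float(acc[-1, end]), j, end


def retrieve_top_k(query, trajectories, k=100, min_len=20):
    """Keep the k cheapest matches that span at least `min_len` frames."""
    matches = []
    for idx, traj in enumerate(trajectories):
        cost, start, end = subsequence_dtw(query, traj)
        if end - start + 1 >= min_len:
            matches.append((cost, idx, start, end))
    return sorted(matches)[:k]


# Toy example: the query is an exact slice of the series.
series = np.array([0., 10., 20., 30., 40., 1., 2., 3., 4., 50., 60.]).reshape(-1, 1)
query = series[5:9].copy()
print(subsequence_dtw(query, series))  # (0.0, 5, 8)
```

The defaults `k=100` and `min_len=20` mirror the default retrieval settings described above. The free start and free end are what distinguish subsequence DTW from standard DTW, which forces the alignment to cover the whole series.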
We designed STRAP to be modular and to work with any dataset that roughly follows the HDF5 structure of robomimic. The embedding code doesn't modify the original dataset, and the retrieval algorithm writes a single HDF5 file containing the retrieved sub-trajectories. To retrieve from a custom dataset, follow these steps:
- Configure a new dataset
  To add a new dataset, create an HDF5FileStructure and two DatasetConfig instances. HDF5FileStructure defines the file structure of the new dataset, e.g., which image observations and proprioceptive information to use, file paths and endings, and data keys. DatasetConfig defines which files to load and provides helper functions to write the output dataset. You'll have to add two datasets (prior and offline).
  See configs/libero_hdf5_config.py for an example.
  Tip: If your dataset follows the LIBERO format, you can reuse the helper functions in configs/libero_file_functions.py.
  Configuration details: here
- Embed the datasets
  Define encoders (get_encoders) and add the new dataset configs (get_datasets) in embedding/encode_datasets.py. Run the script to embed the dataset:
  python strap/embedding/encode_datasets.py
  Configuration details: here
- Retrieve from the dataset
  Define the retrieval arguments (get_args) in retrieval/retrieval.py. Run the script to retrieve from the dataset:
  python strap/retrieval/retrieval.py
  Configuration details: here
├── STRAP
│ ├── data/ # Folder to contain the data
│ │ ├── download_libero.py # Script to download libero datasets
│ ├── strap/
│ │ ├── retrieval/
│ │ │ ├── retrieval.py # Script to run retrieval on embedded datasets using the retrieval model
│ │ ├── embedding/
│ │ │ ├── encoders/
│ │ │ │ ├── encoders.py # Definitions of the different encoders
│ │ │ │ ├── encode_datasets.py # Script to encode a dataset using an encoder for retrieval.
│ │ │ ├── configs/
│ │ │ │ ├── libero_configs.py # Configs for the libero datasets
│ │ │ │ ├── libero_file_functions.py # File functions for the libero datasets
│ │ ├── README.md
│ ├── requirements.txt
@article{memmel2024strap,
title={STRAP: Robot Sub-Trajectory Retrieval for Augmented Policy Learning},
author={Memmel, Marius and Berg, Jacob and Chen, Bingqing and Gupta, Abhishek and Francis, Jonathan},
journal={arXiv preprint arXiv:2412.15182},
year={2024}
}
