This is the official JAX implementation for the paper Mean Flows for One-step Generative Modeling. This code is written and tested on TPUs.
- 25.07.29: Released the PyTorch code for CIFAR-10.
- 25.07.29: Added a JAX+GPU sanity check via this PR. Thanks to @Wenhao!
Run `install.sh` to install the dependencies (JAX+TPUs).
You can quickly verify your setup with a provided MF-B/4 checkpoint.
- Download the checkpoint and FID stats:
  MF-B/4 Checkpoint (Google Drive). Note: The FID stats follow the ADM GitHub repository.
- Unzip the checkpoint:
  `unzip <downloaded_file.zip> -d <your_ckpt_dir>`
  Replace `<downloaded_file.zip>` with the downloaded file name and `<your_ckpt_dir>` with your target checkpoint directory.
- Set up the config:
  - Add `load_from` to `configs/run_b4.yml` and set it to the path of `<your_ckpt_dir>`.
  - Set `fid.cache_ref` to the path of the downloaded FID stats file.
  - Set `eval_only` to `True` in the same config.
- Launch evaluation:
  `bash scripts/launch.sh EVAL_JOB_NAME`
The expected FID is 11.4 for this checkpoint.
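The `fid.cache_ref` key in the config step uses dotted notation for a nested field. As a minimal sketch, assuming the YAML config loads into plain nested dicts, the three evaluation settings correspond to the structure built below. `set_dotted` and `eval_overrides` are hypothetical helpers (not part of this repo), and the paths are placeholders.

```python
from typing import Any


def set_dotted(config: dict, dotted_key: str, value: Any) -> None:
    """Set a nested key such as 'fid.cache_ref' in a plain dict, creating levels as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        # Descend into (or create) the intermediate dict for each parent key.
        node = node.setdefault(key, {})
    node[leaf] = value


def eval_overrides(ckpt_dir: str, fid_stats_path: str) -> dict:
    """Hypothetical helper: collect the three evaluation settings as a nested dict."""
    overrides: dict = {}
    set_dotted(overrides, "load_from", ckpt_dir)
    set_dotted(overrides, "fid.cache_ref", fid_stats_path)
    set_dotted(overrides, "eval_only", True)
    return overrides


# Placeholder paths; substitute your own checkpoint directory and FID stats file.
print(eval_overrides("/path/to/your_ckpt_dir", "/path/to/fid_stats.npz"))
```

In YAML terms, `fid.cache_ref` is the `cache_ref` entry nested under a top-level `fid:` block, while `load_from` and `eval_only` sit at the top level.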
Before training, you need to prepare the ImageNet dataset and compute latent representations:
Download the ImageNet dataset and extract it to your desired location. The dataset should have the following structure:
```
imagenet/
├── train/
│   ├── n01440764/
│   ├── n01443537/
│   └── ...
└── val/
    ├── n01440764/
    ├── n01443537/
    └── ...
```
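Before running the preparation script, you can sanity-check this layout. The sketch below is a hypothetical helper (not part of this repo) that only inspects folder names. It assumes that `train/` and `val/` each contain one subfolder per class, which is 1000 for ImageNet-1k.

```python
from pathlib import Path

EXPECTED_NUM_CLASSES = 1000  # ImageNet-1k


def check_imagenet_layout(root: str) -> list[str]:
    """Return a list of problems with root/{train,val}/<wnid>/ (empty list = looks OK)."""
    problems = []
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # Count only subdirectories; each should be one class (WordNet ID) folder.
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        if len(class_dirs) != EXPECTED_NUM_CLASSES:
            problems.append(
                f"expected {EXPECTED_NUM_CLASSES} class folders in {split_dir}, found {len(class_dirs)}"
            )
    return problems
```

Call it as `check_imagenet_layout("/path/to/imagenet")`. An empty list means the expected folders are present. It does not verify the image files themselves.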
Update the data paths in `scripts/prepare_data.sh`:
```bash
IMAGENET_ROOT="YOUR_IMGNET_ROOT"
OUTPUT_DIR="YOUR_OUTPUT_DIR"
LOG_DIR="YOUR_LOG_DIR"
```
Run the data preparation script to compute latent representations:
```bash
IMAGE_SIZE=256 COMPUTE_LATENT=True bash ./scripts/prepare_data.sh
```
Parameters:
- `IMAGE_SIZE`: Image size for processing (256, 512, or 1024). Latent sizes will be 32x32, 64x64, or 128x128, respectively.
- `COMPUTE_LATENT`: Whether to compute and save the latent dataset (True/False).
- `COMPUTE_FID`: Whether to compute FID statistics (True/False).
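Each latent size listed above is the image size divided by 8, which implies the VAE downsamples each spatial side by a factor of 8. The arithmetic is sketched below. The constant names are illustrative and not taken from the repo.

```python
# Assumption: a downsampling factor of 8 per spatial side, as implied by
# the 256 -> 32, 512 -> 64, 1024 -> 128 mapping listed above.
VAE_DOWNSAMPLE = 8
SUPPORTED_IMAGE_SIZES = (256, 512, 1024)


def latent_hw(image_size: int) -> tuple[int, int]:
    """Spatial (height, width) of the latent for a square image of `image_size` pixels."""
    if image_size not in SUPPORTED_IMAGE_SIZES:
        raise ValueError(f"IMAGE_SIZE must be one of {SUPPORTED_IMAGE_SIZES}, got {image_size}")
    side = image_size // VAE_DOWNSAMPLE
    return side, side


for size in SUPPORTED_IMAGE_SIZES:
    print(size, "->", latent_hw(size))
```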
The script will:
- Encode ImageNet images to latent representations using a VAE model
- Save the latent dataset to `OUTPUT_DIR/`
- Compute FID statistics and save them to `OUTPUT_DIR/imagenet_{IMAGE_SIZE}_fid_stats.npz`
- Log progress to `LOG_DIR/$USER/`
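FID compares the statistics saved here (the mean and covariance of Inception features over the reference images) with the same statistics computed on generated samples. The comparison is the Fréchet distance between two Gaussians:

FID = ||μ₁ − μ₂||² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2})

The sketch below is a plain NumPy version of that formula, not the repo's evaluation code. It assumes you already have both sets of statistics as arrays. The key names stored inside the `.npz` file are not documented in this README, so loading the file with `np.load` and selecting the right arrays is left as an assumption on your side.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Tr((sigma1 @ sigma2)^{1/2}) is computed as the sum of square roots of the
    eigenvalues of sigma1 @ sigma2. These are real and non-negative for PSD
    covariances, so small negative or imaginary parts from round-off are dropped.
    """
    mu1, mu2 = np.asarray(mu1, dtype=np.float64), np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)
```

Using the eigenvalues of Σ₁Σ₂ avoids an explicit matrix square root. Widely used FID implementations typically call `scipy.linalg.sqrtm` instead, which should agree up to numerical error.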
After data preparation, set the dataset root and the FID cache reference in your config file. Edit the file (e.g., `configs/run_b4.yml`) and replace the placeholder values:
```yaml
dataset:
  root: YOUR_DATA_ROOT  # Path to your prepared latent dataset
fid:
  cache_ref: YOUR_FID_CACHE_REF  # Path to your FID statistics file
```
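The code runs with whatever values the config contains, so it can help to confirm that no `YOUR_...` placeholder is left. This is a minimal sketch that assumes the YAML file has already been parsed into nested dicts and lists, for example with PyYAML's `yaml.safe_load`. `find_placeholders` is a hypothetical helper.

```python
def find_placeholders(config, prefix: str = "YOUR_", path: str = "") -> list[str]:
    """Return dotted paths of string values that still start with `prefix`."""
    found = []
    if isinstance(config, dict):
        for key, value in config.items():
            child = f"{path}.{key}" if path else str(key)
            found.extend(find_placeholders(value, prefix, child))
    elif isinstance(config, list):
        for i, value in enumerate(config):
            found.extend(find_placeholders(value, prefix, f"{path}[{i}]"))
    elif isinstance(config, str) and config.startswith(prefix):
        found.append(path)
    return found


example = {
    "dataset": {"root": "/path/to/latents"},
    "fid": {"cache_ref": "YOUR_FID_CACHE_REF"},
}
print(find_placeholders(example))  # ['fid.cache_ref']
```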
- `configs/run_b4.yml`: Configuration for MF-B/4 model training (recommended)
- `configs/default.py`: Default configuration (Python format, used as the base)
Configuration Hierarchy: The system uses a hierarchical approach where `run_b4.yml` overrides specific parameters from `default.py`. This allows you to customize only the parameters you need while keeping sensible defaults.
Make sure to update both the dataset root path and the FID cache reference path according to your data preparation output.
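The override hierarchy described above can be illustrated with a recursive dictionary merge. This sketch shows the general idea and assumes both configs are nested dicts. The repo's actual merge logic, including how it reads `default.py`, may differ. The default values below are made up, and the overrides use the same fields as the custom config example later in this README.

```python
import copy


def merge_config(base: dict, overrides: dict) -> dict:
    """Return a new config in which `overrides` replaces matching fields of `base`.

    Nested dicts are merged key by key, so an override file only needs to list
    the fields it changes. Every other field keeps its default. Inputs are not modified.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Made-up defaults standing in for configs/default.py (illustration only).
defaults = {
    "training": {"num_epochs": 100, "learning_rate": 1e-4},
    "method": {"guidance_eq": "cfg"},
}
overrides = {"training": {"num_epochs": 80}, "method": {"guidance_eq": "none"}}
print(merge_config(defaults, overrides))
```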
Run the following command to launch training:
```bash
bash scripts/launch.sh JOB_NAME
```
Note: Update the environment variables in `scripts/launch.sh` before running:
- `DATA_ROOT`: Path to your prepared data directory
- `LOG_DIR`: Path where training logs are saved
The training system uses two config files:
- `configs/default.py`: Base configuration with all default hyperparameters
- `configs/run_b4.yml`: Model-specific overrides for MF-B/4 training
The system merges these files, allowing you to customize only the parameters you need.
To create a custom experiment:
- Create a new config file (e.g., `configs/my_exp.yml`).
- Update the launch script to use your config. In `launch.sh`, change the config line to `--config=configs/load_config.py:my_exp`.
Example custom config:
```yaml
training:
  num_epochs: 80  # Train for fewer epochs
method:
  guidance_eq: 'none'  # Disable guidance
```
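The `--config=configs/load_config.py:my_exp` value appears to follow the `ml_collections` `config_flags` convention. In that convention, the text after `:` is passed as an argument to the config file's `get_config`. That reading is an inference and is not confirmed by this README. The hypothetical helper below only splits the flag value and derives the YAML path that the example above suggests (`my_exp` to `configs/my_exp.yml`).

```python
from pathlib import Path


def resolve_experiment(flag_value: str, config_dir: str = "configs") -> tuple[str, str, Path]:
    """Split 'configs/load_config.py:my_exp' into (script, experiment, yaml_path).

    The experiment -> <config_dir>/<experiment>.yml mapping is an assumption
    based on the example in this README.
    """
    script, sep, name = flag_value.partition(":")
    if not sep or not name:
        raise ValueError(f"expected '<script>.py:<experiment>', got {flag_value!r}")
    return script, name, Path(config_dir) / f"{name}.yml"


print(resolve_experiment("configs/load_config.py:my_exp"))
```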
During training, the code logs training metrics to `LOG_DIR/$USER/$JOBNAME/`. You can use TensorBoard to monitor training progress:
```bash
tensorboard --logdir LOG_DIR --port 12666
```
The table below shows the generative performance of the MF-B/4 model.

Settings | FID@80ep | FID@240ep |
---|---|---|
`guidance_eq=none` | 61.09/60.75 | 48.16 |
`guidance_eq=cfg`, | 20.15/20.24 | 13.74 |
`guidance_eq=cfg`, | 19.15/18.70 | 11.35 |
Note: FID@80ep values are given in the format "reported in paper / this repo".
The 2nd and 3rd rows correspond to Table 1(f) and Table 5, using the same effective guidance scale as
- Dependencies and sanity check for JAX+GPU (see this PR).
- PyTorch code for CIFAR-10.
This repo is under the MIT license. See LICENSE for details.
```bibtex
@article{meanflow,
  title={Mean Flows for One-step Generative Modeling},
  author={Geng, Zhengyang and Deng, Mingyang and Bai, Xingjian and Kolter, J Zico and He, Kaiming},
  journal={arXiv preprint arXiv:2505.13447},
  year={2025}
}
```
This repository is a collaborative effort by Kaiming He, Runqian Wang, Qiao Sun, Zhicheng Jiang, Hanhong Zhao, Yiyang Lu, Xianbang Wang, and Zhengyang Geng, developed in support of several research projects.
We gratefully acknowledge the Google TPU Research Cloud (TRC) for granting TPU access. We hope this work will serve as a useful resource for the open-source community.
- Our MeanFlow PyTorch repo with CIFAR experiments.
- zhuyu-cs/MeanFlow: PyTorch training code with reproduced ImageNet results.
- pkulwj1994/easy_meanflow: PyTorch implementation with DDP+JVP and metrics for CIFAR-10.
- HaoyiZhu/MeanFlow-PyTorch: PyTorch implementation with ImageNet training code.
- haidog-yaqub/MeanFlow: PyTorch code for MNIST and CIFAR-10.
- noamelata/MeanFlow: PyTorch code for ImageNet.