This is a PyTorch re-implementation of the paper *Mean Flows for One-step Generative Modeling*. The code is based on the official JAX implementation and REPA.
```bash
conda create -n meanflow python=3.10 -y
conda activate meanflow
pip install -r requirements.txt
```
You can generate images (the resulting `.npz` file can be used with the ADM evaluation suite) with the following script:
```bash
torchrun --nnodes=1 --nproc_per_node=8 generate_meanflow.py \
  --model SiT-B/4 \
  --num-fid-samples 50000 \
  --ckpt [YOUR_CHECKPOINT_PATH] \
  --per-proc-batch-size=64 \
  --vae=ema \
  --num-steps=1
```
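If you want to sanity-check the generated file before running the full evaluation, here is a minimal sketch. It assumes the ADM-style `.npz` layout in which the image array is stored under the `arr_0` key; the filename `samples.npz` is a placeholder for whatever path `generate_meanflow.py` writes.

```python
import numpy as np

# Sanity-check the generated samples. "samples.npz" is a placeholder;
# use the actual path reported by generate_meanflow.py.
# Assumption: ADM-style layout with images stored under the "arr_0" key.
samples = np.load("samples.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expect (50000, 256, 256, 3) uint8
```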
The official repository provides a JAX checkpoint for SiT-B/4. I have converted it into a PyTorch checkpoint, which you can download here. You can set `[YOUR_CHECKPOINT_PATH]` to the path of the downloaded `meanflow-B4.pth` and evaluate this checkpoint with the command above.
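Before evaluating, you may want to confirm the checkpoint loads cleanly. A minimal sketch is below; the exact key layout of the converted file (a flat state dict vs. nested entries such as `ema` or `model`) is an assumption, so print the keys and adapt to what you see.

```python
import torch

# Inspect the converted checkpoint before evaluation.
# Assumption: the file is a dict-like state dict; the key layout may
# be flat or nested (e.g., under "ema" or "model") -- check and adapt.
state = torch.load("meanflow-B4.pth", map_location="cpu")
if isinstance(state, dict):
    print(list(state.keys())[:10])
```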
After obtaining the `.npz` result, first create a new conda environment (to avoid dependency conflicts) following the ADM evaluation instructions, and download its `VIRTUAL_imagenet256_labeled.npz`. Then run the following command to compute the metrics:
```bash
# in your new eval environment
python evaluator.py [YOUR_PATH_TO_VIRTUAL_imagenet256_labeled.npz] [YOUR_RESULT_npz]
```
Currently, we provide experiments for ImageNet. Place your data wherever you like and specify its location via the `--data-dir` argument in the training scripts. Please refer to our preprocessing guide.
Example command:
```bash
accelerate launch train_meanflow.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="bf16" \
  --seed=0 \
  --model="SiT-XL/2" \
  --proj-coeff=0.0 \
  --encoder-depth=0 \
  --output-dir="exps" \
  --exp-name="meanflow-sitxl" \
  --batch-size=256 \
  --adam-beta2 0.95 \
  --epochs 240 \
  --gradient-accumulation-steps 2 \
  --t-start 0.0 \
  --t-end 0.75 \
  --omega 0.2 \
  --kappa 0.92 \
  --data-dir=[YOUR_DATA_PATH]
```
This script will automatically create a folder under `exps` to save logs and checkpoints. You can adjust the following options:
- `--model`: one of `[SiT-B/4, SiT-B/2, SiT-L/2, SiT-XL/2]`
- `--output-dir`: any directory where you want checkpoints and logs to be saved
- `--exp-name`: any string name (the folder will be created under `--output-dir`)
Warning: This repository is forked from REPA, and I have kept some REPA options (such as `--proj-coeff` and `--encoder-depth`). However, they are not actually implemented or supported yet, so always disable them (as in the example command above).
Note: The `--batch-size` option specifies the global batch size distributed across all devices. The actual local batch size on each GPU is calculated as `batch-size // num-gpus // gradient-accumulation-steps`.
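As a concrete check of that formula, here is a small worked example using the numbers from the training command above and a hypothetical 8-GPU machine:

```python
# Worked example: global batch 256, 8 GPUs, 2 gradient-accumulation steps.
global_batch_size = 256        # --batch-size
num_gpus = 8                   # hypothetical device count
grad_accum_steps = 2           # --gradient-accumulation-steps
local_batch_size = global_batch_size // num_gpus // grad_accum_steps
print(local_batch_size)        # 16 samples per GPU per forward/backward pass
```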
- I have made this repository runnable, but I have not yet trained and evaluated it with the exact settings from the original paper to verify that the performance matches. If you find any mismatches or implementation errors, or if you use this repository to reproduce the original paper's results, feel free to let me know!
- Due to the incompatibility between the Jacobian-vector product (jvp) operation and FlashAttention, the `fused_attn` flag should always be disabled for training. For evaluation, the flag can be enabled; see the sketch below.
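A minimal sketch of toggling the flag is below. It assumes the SiT attention modules expose a timm-style `fused_attn` attribute; the helper name `set_fused_attn` and the variable `model` are hypothetical.

```python
import torch.nn as nn

def set_fused_attn(model: nn.Module, enabled: bool) -> None:
    """Toggle timm-style fused attention on every submodule that has it."""
    for module in model.modules():
        if hasattr(module, "fused_attn"):
            module.fused_attn = enabled

# Usage with your SiT instance (hypothetical variable `model`):
# set_fused_attn(model, False)  # training: torch.func.jvp needs the unfused path
# set_fused_attn(model, True)   # evaluation: fused/Flash attention is fine
```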
This code is mainly built upon REPA and the official JAX implementation of MeanFlow.
```bibtex
@article{meanflow,
  title={Mean Flows for One-step Generative Modeling},
  author={Geng, Zhengyang and Deng, Mingyang and Bai, Xingjian and Kolter, J Zico and He, Kaiming},
  journal={arXiv preprint arXiv:2505.13447},
  year={2025}
}
```