- Full documentation: see the cebmf-torch documentation.
- Example notebooks: see the `examples/` directory for Jupyter notebooks demonstrating typical workflows.
- To run the example notebooks, first install the additional dependencies with `uv sync --group examples` (or `pip install ".[examples]"` if using pip).
cebmf_torch is a pure-PyTorch implementation of Empirical Bayes Matrix Factorization (EBMF) and Empirical Bayes Normal Means (EBNM) methods. It is designed for scalable, GPU-accelerated analysis of large datasets, with a focus on genomics and other high-dimensional applications. The package provides flexible prior families, efficient mini-batch EM, and full support for GPU computation.
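For orientation, the underlying models can be written in the usual EBMF/EBNM notation (generic notation, not tied to this package's API):

$$
Y = L F^\top + E, \qquad E_{ij} \sim \mathcal{N}\!\left(0, \tau^{-1}\right), \qquad L_{\cdot k} \sim g^{(k)}_\ell, \quad F_{\cdot k} \sim g^{(k)}_f,
$$

where the priors $g^{(k)}_\ell$ and $g^{(k)}_f$ are estimated from the data. Each factor update reduces to an Empirical Bayes Normal Means problem,

$$
\hat{\beta}_j \mid \beta_j \sim \mathcal{N}\!\left(\beta_j, s_j^2\right), \qquad \beta_j \sim g \in \mathcal{G},
$$

in which $g$ is estimated by maximizing the marginal likelihood and the posterior means and variances of the $\beta_j$ are returned.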
- GPU-accelerated: All core computations are performed in PyTorch.
- Flexible priors: Easily extendable to new prior families.
- Mini-batch EM: Fast optimization for large datasets.
- Posterior inference: Compute posterior means and variances for all supported models.
- Empirical Bayes Matrix Factorization (EBMF) with flexible priors
- Empirical Bayes Normal Means (EBNM) solvers (normal, exponential, Laplace, point-mass, etc.)
- GPU support for all operations
- Mini-batch EM and Adam optimizers for mixture weights
- Analytical truncated normal moments for the exponential prior (see the sketch after this list)
- Easy-to-use API for both beginners and advanced users
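The "analytical truncated normal moments" bullet above refers to closed-form expressions such as $E[X \mid X > a] = \mu + \sigma\,\phi(\alpha)/(1 - \Phi(\alpha))$ with $\alpha = (a - \mu)/\sigma$. Below is a minimal, self-contained illustration of that formula in plain torch; the function name is made up for this sketch and it is not the package's internal implementation.

```python
import torch

def truncated_normal_mean(mu, sigma, a=0.0):
    """E[X | X > a] for X ~ N(mu, sigma^2), via the closed-form expression
    mu + sigma * phi(alpha) / (1 - Phi(alpha)), alpha = (a - mu) / sigma."""
    std_normal = torch.distributions.Normal(0.0, 1.0)
    alpha = (a - mu) / sigma
    pdf = torch.exp(std_normal.log_prob(alpha))   # standard normal density at alpha
    tail = 1.0 - std_normal.cdf(alpha)            # P(Z > alpha)
    return mu + sigma * pdf / tail.clamp_min(1e-12)

mu = torch.tensor([0.0, 1.0, -2.0])
sigma = torch.tensor([1.0, 0.5, 2.0])
print(truncated_normal_mean(mu, sigma))  # elementwise truncated means, no SciPy needed
```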
Installation is managed with uv, a fast Python package manager.
```bash
# Clone the repository
git clone https://github.com/william-denault/cebmf_torch.git
cd cebmf_torch

# Install the package and dependencies
uv sync

# Run tests to verify your installation
uv run pytest
```

If you prefer not to use uv, you can also install the package with pip by replacing `uv sync` with `pip install .`.
Use the public Docker image:

```bash
docker pull ghcr.io/william-denault/cebmf_torch:latest
```

or clone the repo and build the image yourself:

```bash
docker build .
```
The Docker image includes:
- CUDA 13.0.1 runtime for GPU acceleration
- Python 3.12 with all dependencies
- Development tools (pytest, etc.)
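Since the image targets GPU use, you will typically want to pass the host GPUs through when running it. A minimal sketch, assuming the NVIDIA Container Toolkit is installed on the host and that the image's default entrypoint is what you want to run:

```bash
# Pass all host GPUs through to the container (requires the NVIDIA Container Toolkit)
docker run --rm -it --gpus all ghcr.io/william-denault/cebmf_torch:latest
```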
Here's how to get started with the main functions:
```python
import torch
from cebmf_torch import ash, cEBMF
# Example: ash with normal mixture prior
n = 10000
betahat = torch.randn(n, device='cuda' if torch.cuda.is_available() else 'cpu')
se = torch.full((n,), 0.5, device=betahat.device)
res = ash(betahat, se, prior='norm', batch_size=8192)
print(res.pi0, res.scale)
# Example: EBMF on a small matrix
Y = torch.randn(500, 200, device=betahat.device)
model = cEBMF(Y, K=5, prior_L='norm', prior_F='norm')
fit = model.fit(maxit=10)
print(fit.L.shape, fit.F.shape, fit.tau.item())
```

- All computations run on the device of your input tensors (CPU or GPU).
- Mini-batch EM for `pi` is implemented via Adam on logits (recommended) or online EM (see the sketch after this list).
- Truncated normal moments are computed analytically in torch (no SciPy required).
- The codebase is modular and easy to extend for new prior families or custom models.
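To give intuition for the "Adam on logits" note above, here is a generic, self-contained toy example of fitting mixture weights by mini-batch gradient steps on the log marginal likelihood. The names and setup are illustrative only and do not reflect the package's internals.

```python
import torch

# Toy setup: L[i, k] = likelihood of observation i under mixture component k
torch.manual_seed(0)
n, K = 10_000, 5
L = torch.rand(n, K) + 1e-6

logits = torch.zeros(K, requires_grad=True)      # unconstrained parameters
opt = torch.optim.Adam([logits], lr=0.05)

for epoch in range(20):
    for idx in torch.randperm(n).split(2048):    # mini-batches
        pi = torch.softmax(logits, dim=0)        # mixture weights on the simplex
        loss = -torch.log(L[idx] @ pi).mean()    # negative mean log marginal likelihood
        opt.zero_grad()
        loss.backward()
        opt.step()

print(torch.softmax(logits, dim=0).detach())     # estimated mixture weights
```

Parameterizing `pi` through a softmax keeps the weights on the simplex without explicit constraints, which is what makes plain Adam applicable.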
Contributions, bug reports, and feature requests are welcome! Please open an issue or pull request on GitHub.
For questions or help, open an issue or contact the maintainer.