This repository contains some JAX example code for a talk about Dirichlet Flow Matching, an approach to "discrete diffusion" over categorical sequences.
I found it easiest to play around with this repository by pip installing the package in editable mode:
git clone https://github.com/ElisR/DirichletFlowMatching.git
cd DirichletFlowMatching
pip install -e .
Below are some notebooks that interactively demonstrate some concepts in the paper.
Note
Unfortunately, GitHub's LaTeX parser is slightly limited, and will aggressively interpret subscript indicators as attempts to italicise text, so I will be using superscript more than I would like.
Warning
This talk is phrased in a way that assumes some familiarity with diffusion models.
Generative models allow sampling
As we all know, generative diffusion models have had considerable success at this task (thinking specifically of image generation models). In its original form, a diffusion model acts on continuous data, where the noising process is easy to interpret.
However, extending diffusion models to discrete data is less straightforward. The appeal is clear, even if we restrict our imagination to (protein) language modelling tasks. For example, one downside of autoregressive models is that inference takes an amount of time proportional to sequence length. Producing one word at a time is also contrary to how we might perceive our own sentences as forming in our heads, guided by intention, before being spoken, or how we would sketch the outline of a document before filling in the details. As a final motivator: the already-successful masked language modelling (MLM) objective (where ~15% of tokens are masked) looks like a one-step denoising process, so what happens if we go further?
With a discrete diffusion model, we could feed our model many token sequences during training (for example amino acid sequences of viable proteins), and produce novel sequences at inference time, hopefully from the same distribution.
Here, we will review Dirichlet Flow Matching (DFM) as a new approach to this problem.
Here is a "TL;DR" of the current approaches to making the idea of diffusion models (i.e. noising the data and learning the gradient of the log probability density) work for discrete data. We can come back to this at the end of the talk.
- Perform something that looks like diffusion on the categorical probabilities.
- Noise the data as a Markov chain acting on the categorical distribution.
- Map the discrete variables into continuous embeddings, then do standard Gaussian diffusion in that space before mapping back.
🔀 Flow Matching (Lipman et al. 2022)
Flow matching provides a training objective similar to those from diffusion models, but applies it to the (continuous) normalising flows of yesteryear.
At a high level, the neural network again learns the small steps needed to incrementally go from a pure noise distribution
On both ends one will have noisy samples
The first problem to overcome is that we don't have
One also assumes knowledge of a conditional vector field
We will construct the target marginal probability path
The next leap is that the marginal vector field that generates
Lipman et al. then show that minimising
This is great!
It lets us train a model to produce samples from
A simplex
DFM goes a step further and relaxes their
During flow matching, this means that the transport destination will be samples from the vertices, but at intermediate times the samples can lie anywhere on the simplex, like a superposition of different valid destinations.
One other modification by DFM is that instead of training their neural network to predict a vector field
This way, at any point in time, the model is trying to guess the correct label of a variable. This may remind you of how diffusion model objectives are often recast to predicting fully denoised samples at all times.
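As a rough illustration of this classifier-style objective, here is a minimal JAX sketch of a cross-entropy loss on predicted class logits. The name `classifier_loss` and the hand-written logits are my own placeholders. In practice the logits would come from a network evaluated on a noisy point at some time, which is not shown here.

```python
import jax
import jax.numpy as jnp


def classifier_loss(logits, labels):
    """Mean cross-entropy between predicted class logits and integer labels."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to the true class of each example.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked.mean()


# Placeholder logits standing in for network outputs (assumption, not repo code).
logits = jnp.array([[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]])
labels = jnp.array([0, 2])
loss = classifier_loss(logits, labels)
```

With completely uninformative (all-zero) logits, this loss reduces to the log of the number of classes, which is a handy sanity check.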
At inference time, the vector field can be parameterised as
A key ingredient was missing from our introduction to flow matching: how does one construct
Before getting started, let's specify the noisy prior distribution to be the uniform density on the simplex, i.e. a Dirichlet distribution (see footnote 3) with all parameters equal to one:
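In JAX, drawing from this prior might look like the sketch below. The helper name `sample_uniform_simplex` and the sizes are illustrative assumptions, not taken from the repository.

```python
import jax
import jax.numpy as jnp


def sample_uniform_simplex(key, num_classes, num_samples):
    """Draw points uniformly from the simplex via a Dirichlet with all parameters 1."""
    alpha = jnp.ones(num_classes)
    return jax.random.dirichlet(key, alpha, shape=(num_samples,))


key = jax.random.PRNGKey(0)
x0 = sample_uniform_simplex(key, num_classes=4, num_samples=8)  # shape (8, 4)
```

Each row of `x0` is non-negative and sums to one, so it is a valid point on the simplex.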
Typically, flow matching papers produce the conditional vector field
The simplest interpolant is just the linear flow map
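To make the linear interpolant concrete, here is a small sketch assuming the usual form $\mathbf{x}^t = (1 - t)\mathbf{x}^0 + t\mathbf{x}^1$, with $\mathbf{x}^0$ drawn from the uniform Dirichlet prior and $\mathbf{x}^1$ a one-hot vertex. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def linear_flow_map(x0, x1, t):
    """Straight-line interpolation between a noise sample x0 and a data vertex x1."""
    return (1.0 - t) * x0 + t * x1


def linear_conditional_velocity(x0, x1):
    """Time derivative of the linear flow map: constant along each path."""
    return x1 - x0


key = jax.random.PRNGKey(0)
x0 = jax.random.dirichlet(key, jnp.ones(4))  # noise sample on the simplex
x1 = jax.nn.one_hot(2, 4)                    # destination vertex (class 2)
xt = linear_flow_map(x0, x1, 0.3)
v = linear_conditional_velocity(x0, x1)
```

Since both endpoints lie on the simplex, so does every intermediate point, and the velocity components sum to zero.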
As DFM points out, however, this design has some pathological behaviour.
Looking back at the modified training objective
Now look again at the linear flow map above, which moves samples at a constant velocity.
We know that at times
Instead of starting from an interpolant, DFM opts for defining
From this, they derive a valid
A wonderful feature that this flow matching implementation retains from diffusion is the capability to do guidance, both with and without a classifier.
Guidance is the ability to generate from a data distribution of a specific class
Following the recipe of classifier-free guidance, if we have class-conditional and unconditional flow models
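A minimal sketch of that recipe, assuming the standard classifier-free guidance combination with a guidance strength `gamma` (where `gamma = 1` recovers the conditional model and `gamma = 0` the unconditional one). The two vector fields below are placeholder arrays rather than outputs of trained models.

```python
import jax.numpy as jnp


def guided_vector_field(v_cond, v_uncond, gamma):
    """Blend conditional and unconditional vector fields with guidance strength gamma."""
    return gamma * v_cond + (1.0 - gamma) * v_uncond


# Placeholder vector fields on a 3-class simplex (components sum to zero).
v_cond = jnp.array([0.3, -0.1, -0.2])
v_uncond = jnp.array([0.1, 0.1, -0.2])
v_guided = guided_vector_field(v_cond, v_uncond, gamma=2.0)
```

Because the weights sum to one, a combination of two fields that are tangent to the simplex stays tangent to the simplex.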
In classifier guidance, one derives a conditional score function from the gradients of a noisy classifier (derived from Bayes' theorem):
This isn't quite so straightforward, however, because classifier gradients may have components that point off the simplex or lead to invalid negative posterior probabilities. In their results section, the DFM authors find that classifier-free guidance performs better than classifier guidance. So, we could justify skipping the details of how to make classifier guidance work here, but I briefly summarise them below.
One has to project the score onto the tangent plane of the simplex by replacing
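As a sketch of what such a projection can look like: the tangent space of the simplex is the set of vectors whose components sum to zero, and the orthogonal projection onto it simply subtracts the mean component. I am assuming here that this is the projection intended; the classifier `toy_log_prob` is a made-up stand-in.

```python
import jax
import jax.numpy as jnp


def project_to_tangent(g):
    """Orthogonally project g onto {v : sum(v) = 0}, the tangent space of the simplex."""
    return g - jnp.mean(g, axis=-1, keepdims=True)


def toy_log_prob(x):
    # Hypothetical stand-in for log p(y | x) from a noisy classifier.
    w = jnp.array([1.0, -2.0, 0.5])
    return jax.nn.log_sigmoid(jnp.dot(w, x))


x = jnp.array([0.2, 0.3, 0.5])
raw_grad = jax.grad(toy_log_prob)(x)         # may point off the simplex
tangent_grad = project_to_tangent(raw_grad)  # components now sum to zero
```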
Distillation aims to reduce the inference time of the iterative generative process while retaining sample quality. This usually involves a "teacher" model and a "student" model, where the student aims to do more work with the same amount of compute. For example, Salimans & Ho's progressive distillation trains a student to perform two of the teacher's denoising steps in one step, then repeats this until inference takes only four steps rather than thousands.
No such distillation techniques exist for discrete diffusion or autoregressive language models,
However, this paper has mapped inference to solving a deterministic ODE based on the vector field
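To illustrate what "solving a deterministic ODE" means in practice, here is a minimal fixed-step Euler integrator in JAX. The toy vector field, which points straight at a fixed vertex, is an assumption chosen so that the result is easy to check; a real model would supply a learned field instead.

```python
import jax
import jax.numpy as jnp


def euler_integrate(vector_field, x0, num_steps):
    """Integrate dx/dt = vector_field(x, t) from t=0 to t=1 with fixed-step Euler."""
    dt = 1.0 / num_steps

    def step(i, x):
        t = i * dt
        return x + dt * vector_field(x, t)

    return jax.lax.fori_loop(0, num_steps, step, x0)


target = jax.nn.one_hot(1, 3)  # destination vertex for the toy field


def toy_vector_field(x, t):
    # Linear conditional field towards `target`; only evaluated for t < 1.
    return (target - x) / (1.0 - t)


x0 = jnp.array([0.5, 0.2, 0.3])
x1 = euler_integrate(toy_vector_field, x0, num_steps=10)
```

Given the same starting noise, the integrator always returns the same endpoint, which is the property a distilled student can be trained to reproduce in fewer steps.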
A quick word about extending this to sequences: it may look odd that we have just been discussing single categorical variables here, but in practice extending this to modelling sequences of categorical variables like protein sequences is trivial.
Sequences just live in a product space of simplexes
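In array terms, a sequence is just a `(seq_len, vocab)` matrix whose rows each lie on their own simplex. The sketch below uses made-up tokens and sizes, and uses linear interpolation purely to show the shapes involved, not as the DFM probability path.

```python
import jax
import jax.numpy as jnp

seq_len, vocab = 6, 4                   # illustrative sizes
tokens = jnp.array([0, 3, 1, 1, 2, 0])  # made-up token sequence

x1 = jax.nn.one_hot(tokens, vocab)      # (seq_len, vocab): one vertex per position
key = jax.random.PRNGKey(0)
# Independent uniform-Dirichlet noise at every position.
x0 = jax.random.dirichlet(key, jnp.ones(vocab), shape=(seq_len,))

t = 0.6
xt = (1.0 - t) * x0 + t * x1            # each row is still a point on the simplex
decoded = xt.argmax(axis=-1)            # most likely token at each position
```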
Their results are evaluated on conditional generation tasks and an unconditional generation task.
Their conditional task aims to produce a DNA sequence with 1024 base pairs having a given transcription/promoter profile.
The target profile comes from a prediction from another model about a sequence in the dataset.
Their metric is the mean squared error between the predicted profile of the generated sequence and the profile of the ground-truth sequence that was originally used as a target.
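For concreteness, the metric itself is a one-liner. The profile values below are invented numbers, and the profile-predicting model is not shown.

```python
import jax.numpy as jnp


def profile_mse(pred_generated, pred_reference):
    """Mean squared error between two predicted profiles of equal shape."""
    return jnp.mean((pred_generated - pred_reference) ** 2)


# Invented profile values, purely for illustration.
pred_generated = jnp.array([0.1, 0.4, 0.2])
pred_reference = jnp.array([0.0, 0.5, 0.2])
mse = profile_mse(pred_generated, pred_reference)
```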
DFM performs better than the other published D3PM, Bit Diffusion, DDSM and language model baselines.
The language model was previously the best with 0.0333 MSE, but took 1024 model evaluations.
However, even their distilled DFM model with one step outperforms this, getting 0.0278 MSE!
Interestingly, linear flow matching, which they introduce and criticise, is not far behind, probably because base pairs have a small vocabulary of
Their second task aims to generate DNA enhancer sequences. DFM again performs better on a Fréchet-Inception-Distance-type metric modified for a classifier that predicts cell types. Their distilled model also retains most of the performance at 100x lower cost.
As with image diffusion models, they find that guidance can even improve unconditional generation. This involves picking a class based on empirical frequency, then guiding towards that class during generation.
Footnotes
1. If the transport/continuity equation is unfamiliar, it's just a differential equation that expresses that a certain quantity must be "conserved". For example, given a snapshot of a fluid, whatever the density distribution and velocity field describing the motion of small fluid parcels, we know for certain that the mass comprising the fluid cannot be created or destroyed, which restricts how the density can evolve.
2. The proof of this is short and just shows that the given vector field satisfies the transport equation by taking $\partial^t p^t(\mathbf{x}) = \int [\partial^t p^t(\mathbf{x} | \mathbf{x}^1)] p^{\text{data}}(\mathbf{x}^1) d\mathbf{x}^1$ and substituting in the transport equation for the conditional vector field, with some switching of integrals and derivatives.
3. The Dirichlet distribution is the conjugate prior for a multinomial distribution in Bayesian statistics, meaning that if we started with a Dirichlet prior over the class probabilities, the posterior distribution over class probabilities after observing samples drawn from a multinomial distribution will also be a Dirichlet distribution with modified parameters. $\boldsymbol{\alpha}$ can be interpreted as the number of prior observations of each class. A prior of $\boldsymbol{\alpha} = (1, \ldots, 1)^T$ therefore means you pretend that you have seen every class once when drawing from a multinomial.