A Path to the Variational Diffusion Loss
Alexander A. Alemi.
Diffusion models have made quite a splash, especially after the open-source release of Stable Diffusion. What are diffusion models, where does the loss come from and what does a simple example look like? I've recently helped open-source a simple, pedagogical, self-contained example colab of a diffusion model trained on EMNIST, which you can find as part of the Variational Diffusion Models (VDM) github page. In this post, I wanted to give some more background and a simple way to motivate where the loss function comes from.
Non-negativity of KL
Let's say we want to build a latent-variable model of our images: a joint distribution over observed images $x$ and latent variables $z$,

$$ p(x, z) = p(x \mid z)\, p(z), $$

whose marginal $p(x) = \int p(x \mid z)\, p(z)\, dz$ we hope will match the true distribution of images $q(x)$.
We can derive the tractable objective used to train these models using the observation that the KL divergence is non-negative and monotonic. The Kullback-Leibler (KL) divergence between any two distributions is non-negative:

$$ \operatorname{KL}\left[ q(x) \,\|\, p(x) \right] \equiv \int q(x) \log \frac{q(x)}{p(x)} \, dx \geq 0. $$
If we marginalize out some subset of the random variables, the KL divergence between the marginal distributions can only be smaller. For any two joint distributions over the same pair of random variables:

$$ \operatorname{KL}\left[ q(x, z) \,\|\, p(x, z) \right] \geq \operatorname{KL}\left[ q(x) \,\|\, p(x) \right]. $$
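Both facts are easy to check numerically when everything is Gaussian, since the KL divergence has a closed form. Here is a quick illustration (my own toy numbers, not anything from the colab):

```python
import numpy as np

def gaussian_kl(mu0, cov0, mu1, cov1):
    """KL[N(mu0, cov0) || N(mu1, cov1)] for multivariate Gaussians."""
    k = len(mu0)
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    return 0.5 * (
        np.trace(cov1_inv @ cov0)
        + diff @ cov1_inv @ diff
        - k
        + np.log(np.linalg.det(cov1) / np.linalg.det(cov0))
    )

# Two arbitrary joint distributions over (x, z).
mu_q, cov_q = np.array([0.0, 0.0]), np.array([[1.0, 0.6], [0.6, 1.0]])
mu_p, cov_p = np.array([0.5, -0.3]), np.array([[1.5, 0.2], [0.2, 0.8]])

kl_joint = gaussian_kl(mu_q, cov_q, mu_p, cov_p)
# Marginalizing a Gaussian just selects the corresponding entries.
kl_marginal = gaussian_kl(mu_q[:1], cov_q[:1, :1], mu_p[:1], cov_p[:1, :1])

print(f"KL[q(x,z)||p(x,z)] = {kl_joint:.4f}")    # non-negative
print(f"KL[q(x)||p(x)]     = {kl_marginal:.4f}")  # smaller than the joint KL
assert kl_joint >= kl_marginal >= 0.0
```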
VAEs
Imagine designing these joint distributions to have different flavors. Think of $q(x, z) = q(z \mid x)\, q(x)$ as a forward process: draw an image from the true data distribution $q(x)$, then encode it into a latent with $q(z \mid x)$. Think of $p(x, z) = p(x \mid z)\, p(z)$ as a reverse process: draw a latent from a simple, known prior $p(z)$, then decode it back into an image with $p(x \mid z)$.
Based on the properties of the KL divergence above, the KL between these two joint distributions is non-negative, and it can only shrink when we marginalize out one of the variables:

$$ \operatorname{KL}\left[ q(x, z) \,\|\, p(x, z) \right] \geq \operatorname{KL}\left[ q(x) \,\|\, p(x) \right] \geq 0. $$
So, by minimizing the KL between our forward and reverse processes -- by aligning the two joint distributions -- we can ensure that we make progress towards learning a good generative model of our images: pushing the joint KL down also pushes down the KL between the true image distribution $q(x)$ and our model's marginal $p(x)$.
The tightness of this bound is controlled by how close together the remaining conditional distributions are:

$$ \operatorname{KL}\left[ q(x, z) \,\|\, p(x, z) \right] - \operatorname{KL}\left[ q(x) \,\|\, p(x) \right] = \mathbb{E}_{q(x)}\!\left[ \operatorname{KL}\left[ q(z \mid x) \,\|\, p(z \mid x) \right] \right]. $$
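One way to see this (a standard chain-rule calculation, spelled out here for completeness) is to factor both joints into a marginal over $x$ and a conditional over $z$:

$$ \begin{aligned}
\operatorname{KL}\left[ q(x, z) \,\|\, p(x, z) \right]
&= \mathbb{E}_{q(x, z)}\!\left[ \log \frac{q(x)\, q(z \mid x)}{p(x)\, p(z \mid x)} \right] \\
&= \mathbb{E}_{q(x)}\!\left[ \log \frac{q(x)}{p(x)} \right] + \mathbb{E}_{q(x)}\, \mathbb{E}_{q(z \mid x)}\!\left[ \log \frac{q(z \mid x)}{p(z \mid x)} \right] \\
&= \operatorname{KL}\left[ q(x) \,\|\, p(x) \right] + \mathbb{E}_{q(x)}\!\left[ \operatorname{KL}\left[ q(z \mid x) \,\|\, p(z \mid x) \right] \right].
\end{aligned} $$

The second term is an expectation of KL divergences, so it is non-negative; that gives both the bound and the expression for its gap.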
So, again, all we started with is the idea of two different processes: a forward process that takes images and encodes them, and a reverse process that samples some latents from a known distribution and decodes them. If we try to minimize the KL divergence between these two processes, forward to reverse, we can ensure that we are minimizing a valid bound on the marginal KL between the true image distribution $q(x)$ and the marginal of our generative model $p(x)$.
At the end of the day, the hope and the dream we seem to have in doing latent variable modeling is that maybe we will somehow be more successful in learning a reverse process -- a prior $p(z)$ and a decoder $p(x \mid z)$ -- piece by piece than we would have been trying to model the image density $p(x)$ directly.
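To make the VAE version of this concrete, here is a minimal sketch of the per-example objective that falls out of the joint KL, assuming a diagonal-Gaussian encoder, a standard normal prior, and a unit-variance Gaussian decoder; the function names and the toy linear "decoder" are placeholders of my own, not anything from the colab. Averaging this quantity over the data and subtracting the (constant) entropy of the data distribution gives exactly the joint KL above.

```python
import numpy as np

rng = np.random.default_rng(0)

def neg_elbo(x, enc_mu, enc_logvar, decode, n_samples=16):
    """Per-example bound: KL[q(z|x) || N(0, I)] - E_{q(z|x)}[log p(x|z)].

    enc_mu, enc_logvar: parameters of the diagonal-Gaussian encoder q(z|x).
    decode: maps z to the mean of a unit-variance Gaussian decoder p(x|z).
    """
    # Closed-form KL between a diagonal Gaussian and a standard normal prior.
    kl = 0.5 * np.sum(np.exp(enc_logvar) + enc_mu**2 - 1.0 - enc_logvar)
    # Monte Carlo estimate of the expected reconstruction log-likelihood.
    std = np.exp(0.5 * enc_logvar)
    zs = enc_mu + std * rng.standard_normal((n_samples, enc_mu.size))
    recon = np.mean([
        -0.5 * np.sum((x - decode(z))**2 + np.log(2 * np.pi)) for z in zs
    ])
    return kl - recon

# Toy example: a "decoder" that is just a fixed linear map.
W = rng.standard_normal((4, 2))
x = rng.standard_normal(4)
print(neg_elbo(x, enc_mu=np.zeros(2), enc_logvar=np.zeros(2), decode=lambda z: W @ z))
```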
Diffusion
For diffusion models, honestly, there isn't much to add except that they use many more steps. The only difference is that instead of a two-step forward process, in diffusion we imagine a many-stepped (or potentially continuous) forward and reverse process.
In particular, in most diffusion models we fix the forward process to be a Markov chain:

$$ q(z_1, \ldots, z_T \mid x) = q(z_1 \mid x) \prod_{t=2}^{T} q(z_t \mid z_{t-1}), $$

where each conditional $q(z_t \mid z_{t-1})$ is a fixed Gaussian that slightly scales down the previous latent and adds a small amount of fresh noise.
This takes an ordinary image and then adds more and more noise to it until it looks more or less indistinguishable from just isotropic Gaussian noise.
One particularly nice thing about using Gaussians for every step of the forward process here is that the composition of a bunch of conditional Gaussians is itself Gaussian, so we will have a closed form for the marginal distribution at any intermediate time:

$$ q(z_t \mid x) = \mathcal{N}\!\left( z_t;\; \alpha_t\, x,\; \sigma_t^2 I \right). $$
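In code, this means we can jump straight to any point in the forward process with a single sample. A minimal sketch, assuming a simple variance-preserving schedule (the actual colab uses its own noise schedule; `diffuse` and the linear schedule here are just illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def diffuse(x, t, T=1000):
    """Sample z_t ~ q(z_t | x) = N(alpha_t * x, sigma_t^2 * I) in one shot.

    Illustrative variance-preserving schedule: sigma_t^2 ramps linearly
    from 0 at t=0 to 1 at t=T, and alpha_t^2 = 1 - sigma_t^2.
    """
    sigma2 = t / T
    alpha = np.sqrt(1.0 - sigma2)
    sigma = np.sqrt(sigma2)
    eps = rng.standard_normal(x.shape)
    return alpha * x + sigma * eps

x = rng.standard_normal((28, 28))      # stand-in for an EMNIST image
z_mid = diffuse(x, t=500)              # halfway: a visibly noisy image
z_end = diffuse(x, t=1000)             # endpoint: pure Gaussian noise
```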
With a forward process defined, we parameterize or learn the reverse process, a Markov chain that operates in the opposite direction:

$$ p(x, z_1, \ldots, z_T) = p(z_T) \left[ \prod_{t=2}^{T} p(z_{t-1} \mid z_t) \right] p(x \mid z_1). $$
The VDM loss is simply the KL between these two joints, which serves as an upper bound on the KL of the image marginals:

$$ \operatorname{KL}\left[ q(x, z_1, \ldots, z_T) \,\|\, p(x, z_1, \ldots, z_T) \right] \geq \operatorname{KL}\left[ q(x) \,\|\, p(x) \right]. $$
Just as in the case of a VAE, here, the hope is that it might actually be easier to model the larger joint distribution than it was to try to model the density directly. In the case of simple diffusion models, the forward process is fixed additive Gaussian noise. If we take enough steps in the forward process, we believe we ought to be able to learn the reverse process exactly.
Various Sundry Tricks
The joint KL is equivalent to the VDM loss. However, in practice, to make this loss efficient to train, diffusion models leverage a lot of the known structure
of the forward process to power a very clever parameterization of the reverse process. This requires some tricky rearranging of terms and some stochastic approximation to make the whole thing efficient.
To see the code, please check out the example colab as well as its accompanying text, which walks through these steps in more detail.
To utilize our knowledge of the forward process, we're actually going to rewrite the forward process not as a sequence of conditional Gaussian steps (a bottom-up forward process):

$$ q(z_1, \ldots, z_T \mid x) = q(z_1 \mid x) \prod_{t=2}^{T} q(z_t \mid z_{t-1}), $$

but instead as a top-down process that starts at the noisiest latent and works its way back towards the image, always conditioned on the original data:

$$ q(z_1, \ldots, z_T \mid x) = q(z_T \mid x) \prod_{t=2}^{T} q(z_{t-1} \mid z_t, x). $$

Because everything in the forward process is Gaussian, each of these top-down conditionals $q(z_{t-1} \mid z_t, x)$ is itself a Gaussian we can write down in closed form.
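For reference (this is standard Gaussian conjugacy, written out here rather than quoted from the colab), if we define the per-step coefficients $\alpha_{t \mid t-1} = \alpha_t / \alpha_{t-1}$ and $\sigma_{t \mid t-1}^2 = \sigma_t^2 - \alpha_{t \mid t-1}^2\, \sigma_{t-1}^2$, the top-down conditionals work out to

$$ q(z_{t-1} \mid z_t, x) = \mathcal{N}\!\left( z_{t-1};\; \frac{\alpha_{t \mid t-1}\, \sigma_{t-1}^2}{\sigma_t^2}\, z_t + \frac{\alpha_{t-1}\, \sigma_{t \mid t-1}^2}{\sigma_t^2}\, x,\;\; \frac{\sigma_{t-1}^2\, \sigma_{t \mid t-1}^2}{\sigma_t^2}\, I \right). $$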
We'll then parameterize our reverse process to mirror this top-down structure, replacing the true image $x$ (which the reverse process doesn't get to see) with a guess made from the current noisy latent:

$$ p(z_{t-1} \mid z_t) = q\!\left( z_{t-1} \mid z_t,\; x = \hat{x}_\theta(z_t; t) \right). $$
The actual parametric model in a diffusion model is this bit, the denoising network $\hat{x}_\theta(z_t; t)$: a neural network that looks at a noisy latent at noise level $t$ and tries to predict the clean image that generated it (or, equivalently, the noise $\hat{\epsilon}_\theta(z_t; t)$ that was added).
With these choices in place, we can now look at the full joint KL and organize terms.
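Writing $H[q(x)]$ for the (constant) entropy of the data distribution, the joint KL organizes into a reconstruction term, a prior term, and a sum of per-step diffusion terms (the standard decomposition, written out here for concreteness):

$$ \operatorname{KL}\left[ q(x, z_1, \ldots, z_T) \,\|\, p(x, z_1, \ldots, z_T) \right] = \underbrace{\mathbb{E}_{q}\!\left[ -\log p(x \mid z_1) \right]}_{\text{reconstruction}} + \underbrace{\mathbb{E}_{q(x)}\!\left[ \operatorname{KL}\left[ q(z_T \mid x) \,\|\, p(z_T) \right] \right]}_{\text{prior}} + \sum_{t=2}^{T} \underbrace{\mathbb{E}_{q}\!\left[ \operatorname{KL}\left[ q(z_{t-1} \mid z_t, x) \,\|\, p(z_{t-1} \mid z_t) \right] \right]}_{\text{diffusion}} - H[q(x)]. $$

Only the diffusion terms depend on the denoising network, and there are $T - 1$ of them, one for every step of the chain.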
The last trick we're going to use is that we're going to avoid computing all of the terms in our sum by simply not computing all of the terms in our sum. We'll approximate the sum with Monte Carlo: we'll simply randomly choose one of the terms and upweight it appropriately. At that point, we have the loss function used to train VDM models. A very nice thing about the VDM loss is that it is clear that we are optimizing a bound on the marginal likelihood of our generative model. As you can learn in the VDM Paper, many of the diffusion models you've heard about correspond to a weighted form of this same objective, where different terms in the sum get different weights.
After going through all of the fancy math, the analytic KL divergences involved in the diffusion loss simplify quite nicely: each step of the sum turns into a weighted squared error between the true image and the network's reconstruction,

$$ \operatorname{KL}\left[ q(z_{t-1} \mid z_t, x) \,\|\, p(z_{t-1} \mid z_t) \right] = \frac{1}{2} \left( \operatorname{SNR}(t-1) - \operatorname{SNR}(t) \right) \left\| x - \hat{x}_\theta(z_t; t) \right\|_2^2, $$

where $\operatorname{SNR}(t) = \alpha_t^2 / \sigma_t^2$ is the signal-to-noise ratio of the marginal $q(z_t \mid x)$.
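Putting the last two pieces together, a single training step looks roughly like the sketch below. This is heavily simplified relative to the actual colab: `xhat`, the linear schedule, and the function names are placeholders of my own, and the reconstruction and prior terms are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

def snr(t, alpha, sigma):
    """Signal-to-noise ratio of the marginal q(z_t | x)."""
    return alpha[t]**2 / sigma[t]**2

def diffusion_loss_estimate(x, xhat, alpha, sigma, T):
    """Single-sample Monte Carlo estimate of the sum of diffusion terms.

    Picks one timestep uniformly at random and upweights its term by T - 1,
    so the estimate is unbiased for the full sum over t = 2, ..., T.
    """
    t = rng.integers(2, T + 1)                  # random step in {2, ..., T}
    eps = rng.standard_normal(x.shape)
    z_t = alpha[t] * x + sigma[t] * eps         # one-shot sample from q(z_t | x)
    weight = snr(t - 1, alpha, sigma) - snr(t, alpha, sigma)
    per_step_kl = 0.5 * weight * np.sum((x - xhat(z_t, t))**2)
    return (T - 1) * per_step_kl

# Toy setup: linear variance-preserving schedule and a "denoiser" that just
# rescales the latent back towards the data (a placeholder, not a real model).
T = 1000
sigma = np.sqrt(np.linspace(1e-4, 0.999, T + 1))
alpha = np.sqrt(1.0 - sigma**2)
xhat = lambda z_t, t: z_t / alpha[t]

x = rng.standard_normal((28, 28))
print(diffusion_loss_estimate(x, xhat, alpha, sigma, T))
```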
Closing Thoughts
So, why are diffusion models so interesting? Well, first and foremost, the reason they are drawing so much attention is that they have shown tremendous performance. It feels like for the first time we have models that are able to generate very high resolution, very high fidelity natural images. Projects like DALL-E2, Imagen, and Stable Diffusion show really impressive results. What is the magic driving these models?
At a high level, I think we can say that diffusion models start to realize the dream of latent variable models. Sometimes, when you are faced with a problem that is too difficult, you can crack it if you consider an even harder, related problem. As I tried to demonstrate here, even for simple latent variable models like VAEs and especially for diffusion models, one reason we can point to for their success is that instead of directly modeling the distribution over images, they model a much larger joint distribution. That larger joint distribution is strictly speaking a bigger thing to attempt to model, but here we get to design the forward process in such a way that even if there are many pieces to the forward process, those pieces individually are easier to tackle.
However, if that were the case, shouldn't we have expected deep hierarchical models to perform similarly awesomely? Probably, though here I think there is another real trick that diffusion has up its sleeve. For a general deep hierarchical generative model, even if by splitting the problem up into smaller pieces you might have split it up into easier-to-model tasks, to evaluate the joint KL you still need to evaluate all of those terms. That is, as your model becomes richer and more computationally expressive because of its depth, so does the cost of training your model, as you have to evaluate all of the layers at each step in the training process.
Diffusion models avoid this by structuring their forward process in such a way that all of the steps share a great deal of structural similarity. This allows diffusion to approximate a sum of a potentially large number of steps by a single randomly chosen step. If each step looks more or less the same, you can get a good estimate for the whole sum by looking at an individual, random, term.
The last trick up its sleeve: even if you managed to design a deep hierarchical generative model with this structural homogeneity property, to get to some intermediate position in the hierarchy you would still have to run roughly half of the full forward process. That would still be expensive in general. Here, diffusion avoids that entirely.
As boring as a sequence of conditional Gaussians is as a forward process, it is also beautiful: it enables exact analytic marginalization to intermediate steps. You can very quickly mimic the result of adding hundreds of steps of additive Gaussian noise by simply adding a moderate amount of Gaussian noise in a single shot.
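As a quick numerical illustration of that point (again my own toy example, not from the post), running the forward chain step by step lands you in exactly the same distribution as the closed-form marginal:

```python
import numpy as np

rng = np.random.default_rng(0)

# Variance-preserving schedule: marginal q(z_t | x) = N(alpha[t] * x, sigma[t]^2).
T = 1000
sigma = np.sqrt(np.linspace(1e-4, 0.999, T + 1))
alpha = np.sqrt(1.0 - sigma**2)

x = 1.7              # a single scalar "pixel" for simplicity
n = 200_000          # number of chains to run in parallel

# Run the forward chain step by step from t = 1 up to t = T.
z = alpha[1] * x + sigma[1] * rng.standard_normal(n)
for t in range(2, T + 1):
    a_step = alpha[t] / alpha[t - 1]
    var_step = sigma[t]**2 - a_step**2 * sigma[t - 1]**2
    z = a_step * z + np.sqrt(var_step) * rng.standard_normal(n)

# Compare against the closed-form marginal at t = T.
print(z.mean(), alpha[T] * x)   # both ~ alpha_T * x
print(z.std(), sigma[T])        # both ~ sigma_T
```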
So, ultimately, what do I think is one of the main reasons diffusion models do so well? I think it's because they can do so well! I think it's because they are very powerful, expressive, generative models. Sampling from them is generally rather expensive. Drawing a sample means running the full reverse process, which might mean calling the central score net a thousand or so times. That is a very powerful and very expressive generative model, but magically, we can train that generative model's likelihood without ever having to actually instantiate the full generative process at training time due to our set of sundry tricks.
I'm excited to see where this all goes and hope this post and the colab help to introduce these magical models to a wider audience.
Special thanks to Ben Poole, Pavel Izmailov, Christopher Suter, Sergey Ioffe, and Ian Fischer for helpful feedback on this post.