$$ \newcommand{\bone}{\mathbf{1}} \newcommand{\bbeta}{\mathbf{\beta}} \newcommand{\bdelta}{\mathbf{\delta}} \newcommand{\bepsilon}{\mathbf{\epsilon}} \newcommand{\blambda}{\mathbf{\lambda}} \newcommand{\bomega}{\mathbf{\omega}} \newcommand{\bpi}{\mathbf{\pi}} \newcommand{\bphi}{\mathbf{\phi}} \newcommand{\bvphi}{\mathbf{\varphi}} \newcommand{\bpsi}{\mathbf{\psi}} \newcommand{\bsigma}{\mathbf{\sigma}} \newcommand{\btheta}{\mathbf{\theta}} \newcommand{\btau}{\mathbf{\tau}} \newcommand{\ba}{\mathbf{a}} \newcommand{\bb}{\mathbf{b}} \newcommand{\bc}{\mathbf{c}} \newcommand{\bd}{\mathbf{d}} \newcommand{\be}{\mathbf{e}} \newcommand{\boldf}{\mathbf{f}} \newcommand{\bg}{\mathbf{g}} \newcommand{\bh}{\mathbf{h}} \newcommand{\bi}{\mathbf{i}} \newcommand{\bj}{\mathbf{j}} \newcommand{\bk}{\mathbf{k}} \newcommand{\bell}{\mathbf{\ell}} \newcommand{\bm}{\mathbf{m}} \newcommand{\bn}{\mathbf{n}} \newcommand{\bo}{\mathbf{o}} \newcommand{\bp}{\mathbf{p}} \newcommand{\bq}{\mathbf{q}} \newcommand{\br}{\mathbf{r}} \newcommand{\bs}{\mathbf{s}} \newcommand{\bt}{\mathbf{t}} \newcommand{\bu}{\mathbf{u}} \newcommand{\bv}{\mathbf{v}} \newcommand{\bw}{\mathbf{w}} \newcommand{\bx}{\mathbf{x}} \newcommand{\by}{\mathbf{y}} \newcommand{\bz}{\mathbf{z}} \newcommand{\bA}{\mathbf{A}} \newcommand{\bB}{\mathbf{B}} \newcommand{\bC}{\mathbf{C}} \newcommand{\bD}{\mathbf{D}} \newcommand{\bE}{\mathbf{E}} \newcommand{\bF}{\mathbf{F}} \newcommand{\bG}{\mathbf{G}} \newcommand{\bH}{\mathbf{H}} \newcommand{\bI}{\mathbf{I}} \newcommand{\bJ}{\mathbf{J}} \newcommand{\bK}{\mathbf{K}} \newcommand{\bL}{\mathbf{L}} \newcommand{\bM}{\mathbf{M}} \newcommand{\bN}{\mathbf{N}} \newcommand{\bP}{\mathbf{P}} \newcommand{\bQ}{\mathbf{Q}} \newcommand{\bR}{\mathbf{R}} \newcommand{\bS}{\mathbf{S}} \newcommand{\bT}{\mathbf{T}} \newcommand{\bU}{\mathbf{U}} \newcommand{\bV}{\mathbf{V}} \newcommand{\bW}{\mathbf{W}} \newcommand{\bX}{\mathbf{X}} \newcommand{\bY}{\mathbf{Y}} \newcommand{\bZ}{\mathbf{Z}} \newcommand{\bsa}{\boldsymbol{a}} 
\newcommand{\bsb}{\boldsymbol{b}} \newcommand{\bsc}{\boldsymbol{c}} \newcommand{\bsd}{\boldsymbol{d}} \newcommand{\bse}{\boldsymbol{e}} \newcommand{\bsoldf}{\boldsymbol{f}} \newcommand{\bsg}{\boldsymbol{g}} \newcommand{\bsh}{\boldsymbol{h}} \newcommand{\bsi}{\boldsymbol{i}} \newcommand{\bsj}{\boldsymbol{j}} \newcommand{\bsk}{\boldsymbol{k}} \newcommand{\bsell}{\boldsymbol{\ell}} \newcommand{\bsm}{\boldsymbol{m}} \newcommand{\bsn}{\boldsymbol{n}} \newcommand{\bso}{\boldsymbol{o}} \newcommand{\bsp}{\boldsymbol{p}} \newcommand{\bsq}{\boldsymbol{q}} \newcommand{\bsr}{\boldsymbol{r}} \newcommand{\bss}{\boldsymbol{s}} \newcommand{\bst}{\boldsymbol{t}} \newcommand{\bsu}{\boldsymbol{u}} \newcommand{\bsv}{\boldsymbol{v}} \newcommand{\bsw}{\boldsymbol{w}} \newcommand{\bsx}{\boldsymbol{x}} \newcommand{\bsy}{\boldsymbol{y}} \newcommand{\bsz}{\boldsymbol{z}} \newcommand{\bsA}{\boldsymbol{A}} \newcommand{\bsB}{\boldsymbol{B}} \newcommand{\bsC}{\boldsymbol{C}} \newcommand{\bsD}{\boldsymbol{D}} \newcommand{\bsE}{\boldsymbol{E}} \newcommand{\bsF}{\boldsymbol{F}} \newcommand{\bsG}{\boldsymbol{G}} \newcommand{\bsH}{\boldsymbol{H}} \newcommand{\bsI}{\boldsymbol{I}} \newcommand{\bsJ}{\boldsymbol{J}} \newcommand{\bsK}{\boldsymbol{K}} \newcommand{\bsL}{\boldsymbol{L}} \newcommand{\bsM}{\boldsymbol{M}} \newcommand{\bsN}{\boldsymbol{N}} \newcommand{\bsP}{\boldsymbol{P}} \newcommand{\bsQ}{\boldsymbol{Q}} \newcommand{\bsR}{\boldsymbol{R}} \newcommand{\bsS}{\boldsymbol{S}} \newcommand{\bsT}{\boldsymbol{T}} \newcommand{\bsU}{\boldsymbol{U}} \newcommand{\bsV}{\boldsymbol{V}} \newcommand{\bsW}{\boldsymbol{W}} \newcommand{\bsX}{\boldsymbol{X}} \newcommand{\bsY}{\boldsymbol{Y}} \newcommand{\bsZ}{\boldsymbol{Z}} \newcommand{\calA}{\mathcal{A}} \newcommand{\calB}{\mathcal{B}} \newcommand{\calC}{\mathcal{C}} \newcommand{\calD}{\mathcal{D}} \newcommand{\calE}{\mathcal{E}} \newcommand{\calF}{\mathcal{F}} \newcommand{\calG}{\mathcal{G}} \newcommand{\calH}{\mathcal{H}} \newcommand{\calI}{\mathcal{I}} 
\newcommand{\calJ}{\mathcal{J}} \newcommand{\calK}{\mathcal{K}} \newcommand{\calL}{\mathcal{L}} \newcommand{\calM}{\mathcal{M}} \newcommand{\calN}{\mathcal{N}} \newcommand{\calO}{\mathcal{O}} \newcommand{\calP}{\mathcal{P}} \newcommand{\calQ}{\mathcal{Q}} \newcommand{\calR}{\mathcal{R}} \newcommand{\calS}{\mathcal{S}} \newcommand{\calT}{\mathcal{T}} \newcommand{\calU}{\mathcal{U}} \newcommand{\calV}{\mathcal{V}} \newcommand{\calW}{\mathcal{W}} \newcommand{\calX}{\mathcal{X}} \newcommand{\calY}{\mathcal{Y}} \newcommand{\calZ}{\mathcal{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\C}{\mathbb{C}} \newcommand{\N}{\mathbb{N}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\F}{\mathbb{F}} \newcommand{\Q}{\mathbb{Q}} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \newcommand{\nnz}[1]{\mbox{nnz}(#1)} \newcommand{\dotprod}[2]{\langle #1, #2 \rangle} \newcommand{\ignore}[1]{} \let\Pr\relax \DeclareMathOperator*{\Pr}{\mathbf{Pr}} \newcommand{\E}{\mathbb{E}} \DeclareMathOperator*{\Ex}{\mathbf{E}} \DeclareMathOperator*{\Var}{\mathbf{Var}} \DeclareMathOperator*{\Cov}{\mathbf{Cov}} \DeclareMathOperator*{\stddev}{\mathbf{stddev}} \DeclareMathOperator*{\avg}{avg} \DeclareMathOperator{\poly}{poly} \DeclareMathOperator{\polylog}{polylog} \DeclareMathOperator{\size}{size} \DeclareMathOperator{\sgn}{sgn} \DeclareMathOperator{\dist}{dist} \DeclareMathOperator{\vol}{vol} \DeclareMathOperator{\spn}{span} \DeclareMathOperator{\supp}{supp} \DeclareMathOperator{\tr}{tr} \DeclareMathOperator{\Tr}{Tr} \DeclareMathOperator{\codim}{codim} \DeclareMathOperator{\diag}{diag} \newcommand{\PTIME}{\mathsf{P}} \newcommand{\LOGSPACE}{\mathsf{L}} \newcommand{\ZPP}{\mathsf{ZPP}} \newcommand{\RP}{\mathsf{RP}} \newcommand{\BPP}{\mathsf{BPP}} \newcommand{\P}{\mathsf{P}} \newcommand{\NP}{\mathsf{NP}} \newcommand{\TC}{\mathsf{TC}} \newcommand{\AC}{\mathsf{AC}} \newcommand{\SC}{\mathsf{SC}} \newcommand{\SZK}{\mathsf{SZK}} \newcommand{\AM}{\mathsf{AM}} \newcommand{\IP}{\mathsf{IP}} 
\newcommand{\PSPACE}{\mathsf{PSPACE}} \newcommand{\EXP}{\mathsf{EXP}} \newcommand{\MIP}{\mathsf{MIP}} \newcommand{\NEXP}{\mathsf{NEXP}} \newcommand{\BQP}{\mathsf{BQP}} \newcommand{\distP}{\mathsf{dist\textbf{P}}} \newcommand{\distNP}{\mathsf{dist\textbf{NP}}} \newcommand{\eps}{\epsilon} \newcommand{\lam}{\lambda} \newcommand{\dleta}{\delta} \newcommand{\simga}{\sigma} \newcommand{\vphi}{\varphi} \newcommand{\la}{\langle} \newcommand{\ra}{\rangle} \newcommand{\wt}[1]{\widetilde{#1}} \newcommand{\wh}[1]{\widehat{#1}} \newcommand{\ol}[1]{\overline{#1}} \newcommand{\ul}[1]{\underline{#1}} \newcommand{\ot}{\otimes} \newcommand{\zo}{\{0,1\}} \newcommand{\co}{:} %\newcommand{\co}{\colon} \newcommand{\bdry}{\partial} \newcommand{\grad}{\nabla} \newcommand{\transp}{^\intercal} \newcommand{\inv}{^{-1}} \newcommand{\symmdiff}{\triangle} \newcommand{\symdiff}{\symmdiff} \newcommand{\half}{\tfrac{1}{2}} \newcommand{\mathbbm}{\Bbb} \newcommand{\bbone}{\mathbbm 1} \newcommand{\Id}{\bbone} \newcommand{\SAT}{\mathsf{SAT}} \newcommand{\bcalG}{\boldsymbol{\calG}} \newcommand{\calbG}{\bcalG} \newcommand{\bcalX}{\boldsymbol{\calX}} \newcommand{\calbX}{\bcalX} \newcommand{\bcalY}{\boldsymbol{\calY}} \newcommand{\calbY}{\bcalY} \newcommand{\bcalZ}{\boldsymbol{\calZ}} \newcommand{\calbZ}{\bcalZ} $$

An analytic exploration of VAEs

Boukhari, Aymen. Variational Autoencoders (VAEs). Header image, Medium, 22 Oct. 2024, https://medium.com/@aymne011/variational-autoencoders-vaes-24f5da384a9d.

I’ve found that I retain concepts much better if I summarize them a short while after learning them. That said, I don’t think memorizing the nitty-gritty formulas and math is what helps; I believe it’s more useful to remember the whys and hows. Since I want this to become a habit, I decided I may as well write posts covering the things I’d like to keep ingrained. As the title suggests, the first post in this series is on VAEs.

A little background on autoencoders

Figure 1. Grammarly. Autoencoders in Deep Learning. Header image, Grammarly, 12 Feb. 2025, https://contenthub-static.grammarly.com/blog/wp-content/uploads/2024/10/Autoencoders-in-Deep-Learning-760x400.png.

Understanding autoencoders is helpful for learning about variational autoencoders (VAEs). There are two main components to an autoencoder: the encoder and decoder.

The encoder network is the part of the model that compresses inputs into a lower-dimensional representation. The intuition behind this is that lower dimensionality encourages the model to learn the most important features that capture key parts of the data. The “bottleneck” representation that sits between the encoder and decoder is called the latent space.

The decoder network is the part of the model that tries to recreate the original inputs from the latent space features.

In practice, the encoder need not shrink at every layer (nor must the decoder expand monotonically), and the latent dimension may equal or even exceed the input dimension, depending on regularization and architectural choices.
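To make the encoder/bottleneck/decoder picture concrete, here is a minimal sketch of a deterministic autoencoder. It assumes hand-picked linear weights (`ENCODER_W`, `DECODER_W`, both hypothetical) in place of learned, nonlinear networks, and plain Python lists in place of tensors. The encoder keeps two of four coordinates, so inputs that live in that 2-D subspace are reconstructed exactly, while anything outside it is lost at the bottleneck.

```python
# Hypothetical toy weights: the encoder keeps the first two coordinates
# (a 4 -> 2 projection) and the decoder maps them back (2 -> 4).
# A real autoencoder would learn these weights from data.
ENCODER_W = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def transpose(matrix):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*matrix)]


DECODER_W = transpose(ENCODER_W)


def encode(x):
    """Compress a 4-dim input into a 2-dim latent code (the bottleneck)."""
    return matvec(ENCODER_W, x)


def decode(z):
    """Map a 2-dim latent code back to the 4-dim input space."""
    return matvec(DECODER_W, z)


def reconstruction_error(x):
    """Squared error between x and its reconstruction decode(encode(x))."""
    x_hat = decode(encode(x))
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))
```

For example, `reconstruction_error([1.0, 2.0, 0.0, 0.0])` is zero, while `[1.0, 2.0, 3.0, 4.0]` loses its last two coordinates in the bottleneck. Training replaces these fixed weights with ones chosen to minimize reconstruction error over a dataset.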

A probabilistic spin on autoencoders

VAEs differ from traditional autoencoders because they take a probabilistic approach: they learn a distribution over the latent space instead of a fixed mapping. The significance of this is that sampling from the distribution allows VAEs to generate new data instead of just reconstructing the original data, which is why VAEs are classified as generative models. While VAEs are not that different from the original autoencoder in concept, their underlying structure requires a number of additional design considerations.

Overview

Figure 2. Boukhari, Aymen. Variational Autoencoders (VAEs). Header image, Medium, 22 Oct. 2024, https://medium.com/@aymne011/variational-autoencoders-vaes-24f5da384a9d.

Like the traditional model, VAEs can also be broken into their encoder and decoder components.

However, unlike before, the encoder now outputs the parameters of a probability distribution, typically a mean \(\mu_{z \mid x}\) and a variance \(\sigma^2_{z \mid x}\) that together describe a Gaussian with diagonal covariance. To retrieve a latent vector $z$, we effectively sample $z$ from $\mathcal{N}(\mu_{z \mid x}, \mathrm{diag}(\sigma^2_{z \mid x}))$. We say "effectively" because we must apply the reparameterization trick so that gradients can pass through the sampling step; we’ll get into this in a bit. The resulting distribution is typically written as \(q_\phi(z \mid x)\), where $\phi$ denotes the parameters of the encoder network, $z$ is a vector in the latent space, and $x$ is our input.

Once we have a sampled latent vector, we can use a probabilistic decoder to produce a reconstruction $x'$ of the input. As with the encoder network, the decoder network now outputs a distribution rather than a single point. Therefore, to get $x'$, we again sample, this time from \(\mathcal{N}(\mu_{x\mid z}, \mathrm{diag}(\sigma^2_{x\mid z}))\), a distribution we write as $p_\theta(x \mid z)$.
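The decoder side boils down to two operations on a diagonal Gaussian: evaluating its log-density (which is what the training objective later needs) and drawing a sample $x'$ from it. Below is a minimal sketch, assuming the decoder has already produced `mean` and `var` (standing in for $\mu_{x\mid z}$ and $\sigma^2_{x\mid z}$); the function names are illustrative, not from any library.

```python
import math
import random


def diag_gaussian_log_prob(x, mean, var):
    """log N(x; mean, diag(var)): a sum of independent 1-D Gaussian log-densities."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mean, var)
    )


def sample_diag_gaussian(mean, var, rng):
    """Draw x' ~ N(mean, diag(var)) one coordinate at a time."""
    return [rng.gauss(m, math.sqrt(v)) for m, v in zip(mean, var)]
```

Because the covariance is diagonal, both operations factor over dimensions, which is a large part of why this choice is so common.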

Approximating the posterior density \(p_\theta(z \mid x)\)

You might be wondering why we need an encoder that approximates \(p_\theta(z \mid x)\) at all. The reason is that the data likelihood $p_\theta(x) = \int p_\theta(z)p_\theta(x \mid z)\, dz$ is intractable, since evaluating it requires integrating over every possible $z$ in a high-dimensional space! When $p_\theta(x \mid z)$ is parameterized by a neural network, this integral generally has no closed-form solution. It follows that the posterior \(p_\theta(z \mid x) = \frac{p_\theta(x \mid z) p_\theta(z)}{p_\theta(x)}\) is also intractable, because its denominator is exactly this integral.

Therefore, we introduce an encoder $q_\phi(z \mid x)$ that approximates the posterior with a distribution that is tractable to sample from and whose density we can evaluate.
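To see why the integral is the obstacle, consider a toy model chosen (as an assumption for illustration) so that $p_\theta(x)$ does have a closed form: $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(x; z, s^2)$, which gives $p(x) = \mathcal{N}(x; 0, 1 + s^2)$. The sketch below compares that exact value with the naive estimate $\frac{1}{N}\sum_n p(x \mid z_n)$, $z_n \sim p(z)$. In one dimension this works fine; with a high-dimensional $z$ and a neural decoder, most prior draws explain a given $x$ poorly, so an estimator like this can need an impractically large number of samples.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of the 1-D Gaussian N(mean, var) at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def naive_marginal_estimate(x, noise_var, num_samples, rng):
    """Estimate p(x) = E_{z ~ p(z)}[p(x | z)] by sampling z from the prior.

    Toy model (assumption): p(z) = N(0, 1) and p(x | z) = N(z, noise_var).
    """
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        total += normal_pdf(x, z, noise_var)
    return total / num_samples


def exact_marginal(x, noise_var):
    """Closed form that exists only because this toy model is linear-Gaussian."""
    return normal_pdf(x, 0.0, 1.0 + noise_var)
```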

Reparameterization Trick

If you look at Figure 2, you’ll notice we actually draw

\[z = \mu_{z \mid x} + \sigma_{z \mid x} \odot \epsilon\]

with $\epsilon$ sampled from a fixed standard normal $\mathcal{N}(0, I)$ during training. You can check that this is equivalent in distribution to sampling $z$ from a Gaussian whose mean is $\mu_{z\mid x}$ and whose covariance is \(\mathrm{diag}(\sigma^2_{z\mid x})\). The important part is that all of the randomness is kept in $\epsilon$, whose distribution does not depend on any network parameters. This means $z$ is a purely deterministic function of $\mu_{z \mid x}$, $\sigma_{z \mid x}$, and $\epsilon$. So, for a fixed draw of $\epsilon$, the sampling step becomes differentiable with respect to $\mu_{z \mid x}$ and $\sigma_{z \mid x}$, letting gradients flow through it during backpropagation.
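Here is a minimal sketch of the trick in plain Python. It assumes the encoder outputs log-variances rather than $\sigma$ directly (a common convention, not the only one), so $\sigma = \exp(\tfrac12 \log\sigma^2)$ is always positive. The second function writes out the local derivatives by hand to show they are ordinary finite numbers; an autodiff framework would compute them for you.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), elementwise.

    Assumes the encoder outputs log-variances; sigma = exp(0.5 * log_var).
    Returns (z, eps) so the noise used for this draw is visible.
    """
    sigma = [math.exp(0.5 * lv) for lv in log_var]
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, eps


def local_gradients(log_var, eps):
    """Partial derivatives of each z_j for a fixed eps draw.

    dz_j/dmu_j = 1 and dz_j/dlog_var_j = 0.5 * sigma_j * eps_j, so gradients
    can flow from z back to the encoder outputs.
    """
    dz_dmu = [1.0 for _ in eps]
    dz_dlogvar = [0.5 * math.exp(0.5 * lv) * e for lv, e in zip(log_var, eps)]
    return dz_dmu, dz_dlogvar
```

Drawing many samples with `mu=[2.0]` and `log_var=[math.log(0.25)]` should give an empirical mean near 2 and variance near 0.25, matching $\mathcal{N}(\mu_{z\mid x}, \sigma^2_{z\mid x})$.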

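In practice, we draw $K$ noise samples $\epsilon^{(k)} \sim \mathcal{N}(0, I)$, set $z^{(k)} = \mu_{z \mid x} + \sigma_{z \mid x} \odot \epsilon^{(k)}$, and form the Monte Carlo estimate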

\[\mathbb{E} _{z \sim q_\phi(z \mid x)} [\log p_\theta (x \mid z)] \approx \frac 1 K \sum_k \log p_\theta(x \mid z^{(k)})\]

and backpropagate through the average.
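The estimator above can be checked against a case where the expectation is known exactly. This sketch assumes (for illustration only) a 1-D encoder distribution $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and a toy decoder $p(x \mid z) = \mathcal{N}(x; z, s^2)$ whose mean is the identity map; then $\mathbb{E}[(x - z)^2] = (x - \mu)^2 + \sigma^2$ gives a closed form to compare against.

```python
import math
import random


def mc_reconstruction_term(x, mu, sigma, decoder_var, num_samples, rng):
    """Estimate E_{z ~ q}[log p(x | z)] with K reparameterized samples.

    Toy assumptions: q(z | x) = N(mu, sigma^2) and p(x | z) = N(x; z, decoder_var).
    """
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps  # reparameterized draw z^(k)
        total += -0.5 * (math.log(2.0 * math.pi * decoder_var) + (x - z) ** 2 / decoder_var)
    return total / num_samples


def exact_reconstruction_term(x, mu, sigma, decoder_var):
    """Closed form for the toy model, using E[(x - z)^2] = (x - mu)^2 + sigma^2."""
    return -0.5 * (math.log(2.0 * math.pi * decoder_var) + ((x - mu) ** 2 + sigma ** 2) / decoder_var)
```

With a small $K$ (even $K = 1$) the estimate is noisy but unbiased, which is what makes averaging over it and backpropagating a sensible training signal.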

Loss Function

Now that we have our encoder and decoder networks, let’s work on the log-likelihood of the data.

\[{ \small \begin{align*} & \log p_\theta\bigl(x^{(i)}\bigr) \\ &= \mathbb{E}_{z\sim q_\phi(z\mid x^{(i)})}\bigl[\log p_\theta(x^{(i)})\bigr] \quad \text{($p_\theta (x^{(i)})$ independent of $z$)} \\ &= \mathbb{E}_{z}\Bigl[\log\frac{p_\theta(x^{(i)}\mid z)\,p_\theta(z)}{p_\theta(z\mid x^{(i)})}\Bigr] \quad \text{(Bayes' Rule)} \\ &= \mathbb{E}_{z}\Bigl[\log\frac{p_\theta(x^{(i)}\mid z)\,p_\theta(z)}{p_\theta(z\mid x^{(i)})} \frac{q_\phi(z\mid x^{(i)})}{q_\phi(z\mid x^{(i)})} \Bigr] \quad \text{(Identity)} \\ &= \mathbb{E}_{z}\bigl[\log p_\theta(x^{(i)}\mid z)\bigr] - \mathbb{E}_{z}\Bigl[\log\frac{q_\phi(z\mid x^{(i)})}{p_\theta(z)}\Bigr] + \mathbb{E}_{z}\Bigl[\log\frac{q_\phi(z\mid x^{(i)})}{p_\theta(z\mid x^{(i)})}\Bigr] \quad \text{(Logs and Rearranging)} \\ &= \mathbb{E}_{z}\bigl[\log p_\theta(x^{(i)}\mid z)\bigr] \;-\; D_{\mathrm{KL}}\bigl(q_\phi(z\mid x^{(i)})\;\|\;p_\theta(z)\bigr) \;+\; D_{\mathrm{KL}}\bigl(q_\phi(z\mid x^{(i)})\;\|\;p_\theta(z\mid x^{(i)})\bigr) \end{align*} }\]

From the final expression, we use

\[\mathcal{L}\bigl(x^{(i)},\theta,\phi\bigr) = \E_{z\sim q_\phi(z\mid x^{(i)})}\bigl[\log p_\theta(x^{(i)}\mid z)\bigr] \;-\; D_{\mathrm{KL}}\bigl(q_\phi(z\mid x^{(i)})\;\|\;p(z)\bigr) \;\]

for training. We call this the evidence lower bound (ELBO). Notice that the last term of $\log p_\theta\bigl(x^{(i)}\bigr)$ was dropped because it involves the intractable true posterior $p_\theta(z \mid x^{(i)})$. Since KL divergence is always greater than or equal to $0$, dropping that term gives a lower bound on the exact log-likelihood, and this bound is something we can take gradients of and maximize (in practice, by minimizing its negation)! We call $\E_{z\sim q_\phi(z\mid x^{(i)})}\bigl[\log p_\theta(x^{(i)}\mid z)\bigr]$ the reconstruction term; its negation is the reconstruction loss, which penalizes poor reconstructions of $x$ from $z$. Recall that this is the expectation we approximated by drawing $K$ reparameterized samples $z^{(k)} = \mu_{z \mid x} + \sigma_{z \mid x} \odot \epsilon^{(k)}$. The KL term $D_{\mathrm{KL}}\bigl(q_\phi(z\mid x^{(i)})\;\|\;p(z)\bigr)$ pushes the approximate posterior toward the prior.

Some reasons why pulling the approximate posterior toward the prior is useful:

  • encouraging smoothness by pulling $q_\phi(z \mid x)$ toward a broad (usually standard) Gaussian, so that small changes in $z$ produce small, meaningful changes in the decoded $x$
  • imposing structure by pulling posteriors away from isolated spikes into overlapping blobs, which prevents each data point from getting its own isolated clump
  • ensuring that random draws from the prior $p(z)$ come from regions the decoder knows how to map back to realistic $x$
  • limiting the amount of information $z$ can carry about $x$, which can improve generalization
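The decomposition $\log p_\theta(x) = \text{ELBO} + D_{\mathrm{KL}}\bigl(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\bigr)$ and the resulting bound can be verified numerically in a toy model where every term is available in closed form. The sketch assumes $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(x; z, s^2)$, so $p(x) = \mathcal{N}(x; 0, 1 + s^2)$ and the true posterior is $\mathcal{N}\bigl(x/(1+s^2),\, s^2/(1+s^2)\bigr)$. For any Gaussian $q = \mathcal{N}(m, v)$, the ELBO plus the KL gap recovers $\log p(x)$, and the bound is tight exactly when $q$ equals the posterior.

```python
import math

# Toy model (assumption): p(z) = N(0, 1), p(x | z) = N(x; z, s2).
# Then p(x) = N(x; 0, 1 + s2) and p(z | x) = N(x / (1 + s2), s2 / (1 + s2)).


def log_normal(x, mean, var):
    """log N(x; mean, var) for a scalar Gaussian."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def elbo(x, m, v, s2):
    """Reconstruction term minus KL to the prior, for q(z | x) = N(m, v)."""
    recon = -0.5 * (math.log(2.0 * math.pi * s2) + ((x - m) ** 2 + v) / s2)
    return recon - kl_gauss(m, v, 0.0, 1.0)


def log_evidence(x, s2):
    """Exact log p(x) for the toy model."""
    return log_normal(x, 0.0, 1.0 + s2)


def posterior(x, s2):
    """Mean and variance of the exact posterior p(z | x)."""
    return x / (1.0 + s2), s2 / (1.0 + s2)
```

In a real VAE neither $\log p_\theta(x)$ nor the posterior is available, which is exactly why we optimize the ELBO alone.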

Additional Note

In the standard VAE model, because we pick

\[q_\phi(z \mid x) = \mathcal{N}\bigl(z;\,\mu_\phi(x),\,\mathrm{diag}(\sigma_\phi(x)^2)\bigr), \quad p(z) = \mathcal{N}(z;\,0,\,I),\]

the KL divergence has the closed-form expression $\frac 1 2 \sum_j \left(\mu^2_j + \sigma^2_j - 1 - \ln \sigma^2_j\right)$, where $j$ indexes the latent dimensions.
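The closed form above is cheap to compute and can be sanity-checked against a Monte Carlo estimate of $\mathbb{E}_{z \sim q}[\log q(z) - \log p(z)]$. This is a minimal sketch using plain Python lists for $\mu$ and $\sigma^2$; the function names are illustrative.

```python
import math
import random


def kl_to_standard_normal(mu, sigma2):
    """Closed form: 0.5 * sum_j (mu_j^2 + sigma_j^2 - 1 - ln sigma_j^2)."""
    return 0.5 * sum(m * m + s2 - 1.0 - math.log(s2) for m, s2 in zip(mu, sigma2))


def kl_monte_carlo(mu, sigma2, num_samples, rng):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q, for comparison."""
    total = 0.0
    for _ in range(num_samples):
        log_ratio = 0.0
        for m, s2 in zip(mu, sigma2):
            z = rng.gauss(m, math.sqrt(s2))
            log_q = -0.5 * (math.log(2.0 * math.pi * s2) + (z - m) ** 2 / s2)
            log_p = -0.5 * (math.log(2.0 * math.pi) + z * z)
            log_ratio += log_q - log_p
        total += log_ratio
    return total / num_samples
```

The closed form is exactly zero when $\mu = 0$ and $\sigma^2 = 1$ in every dimension, i.e. when $q_\phi(z \mid x)$ already equals the prior; using it instead of sampling removes one source of gradient noise.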

Last thoughts

Once training finishes, generation is straightforward:

  1. Draw a latent $z$ from the prior $p(z) = \mathcal{N}(0, I)$, the same Gaussian we used in the KL term
  2. Pass $z$ through the trained decoder $p_\theta(x \mid z)$, then either sample a new $x'$ from it or take its mean for a deterministic output
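The two generation steps can be sketched as follows. `hypothetical_decoder` is a made-up stand-in for a trained decoder network: it maps a 2-D latent to the mean and variance of a 3-D diagonal Gaussian using a fixed linear map, purely for illustration.

```python
import math
import random


def hypothetical_decoder(z):
    """Stand-in for a trained decoder: returns (mean, variance) of p(x | z).

    The fixed linear map is illustrative only; it expects a 2-dim z.
    """
    mean = [z[0] + 0.5 * z[1], z[0] - z[1], 2.0 * z[1]]
    var = [0.01, 0.01, 0.01]
    return mean, var


def generate(latent_dim, rng, use_mean=False):
    # Step 1: draw z from the prior p(z) = N(0, I).
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    # Step 2: decode; either return the mean or sample x' ~ N(mean, diag(var)).
    mean, var = hypothetical_decoder(z)
    if use_mean:
        return mean
    return [rng.gauss(m, math.sqrt(v)) for m, v in zip(mean, var)]
```

Calling `generate(2, random.Random(0))` returns one new 3-D sample; the encoder plays no part here.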

I originally encountered VAEs in the context of Gaussian Splatting. Note that the encoder’s only job is to amortize inference of the intractable posterior during training; at generation time, only the decoder is needed. Working with VAEs drove home for me that neural networks are simply highly flexible functions that we fit to data.



