Variational autoencoder

This blog post contains a tutorial on the variational autoencoder (VAE). Since its introduction in [Kingma and Welling, 2014, Rezende et al., 2014], the framework has received much attention for its generative modeling and representation learning capabilities. Recent reviews are given in [Doersch, 2016, Kingma and Welling, 2019, Wei et al., 2020]. After a brief theoretical discussion, we demonstrate the VAE algorithm on MNIST data.

Introduction

Essentially, the VAE is a joint approach to likelihood-based learning of and Bayesian inference in latent variable models with neural networks. It establishes a generative modeling scheme for data with a latent structure. For data generation, the latent variables are randomly drawn from a prior sampling distribution and then transformed. Vice versa, the posterior of the latent variables for a given data point is computed by amortized variational inference. Both the data-generating and the variational distribution are parametrized through neural networks. The VAE admits an interpretation as an unsupervised learning technique and probabilistic extension of the autoencoder, which is used for dimensionality reduction or anomaly detection.

We start by discussing some simple statistical models that can be parametrized by neural networks. This includes simple binary or multiclass classification and distribution fitting. Let us investigate standard (supervised) classification first. The goal is to predict the class label \(y\) of an input \(\boldsymbol{x}\). A neural network \(\mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})\) with unknown weights \(\boldsymbol{\theta}\) is used to parametrize a conditional probability distribution \(\pi_{\boldsymbol{\theta}}(y \vert \boldsymbol{x})\). For a multiclass problem the neural network ends with a softmax and models a categorical distribution that can be written using Iverson bracket notation as

\[\pi_{\boldsymbol{\theta}}(y \vert \boldsymbol{x}) = \mathrm{Categorical}(y \vert \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})) = \prod_{m=1}^M \mathrm{NN}_{\boldsymbol{\theta},m}(\boldsymbol{x})^{[y=c_m]}, \quad y \in \{c_1, \ldots, c_M\}.\]

The probability of the \(m\)-th class is simply represented by the corresponding output component \(\pi_{\boldsymbol{\theta}}(y=c_m \vert \boldsymbol{x}) = \mathrm{NN}_{\boldsymbol{\theta},m}(\boldsymbol{x})\). Similarly, for a binary classification problem the network would end with a sigmoid function. The probabilities of positive and negative outcomes of a Bernoulli trial are given as \(\pi_{\boldsymbol{\theta}}(y=1 \vert \boldsymbol{x}) = \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})\) and \(\pi_{\boldsymbol{\theta}}(y=0 \vert \boldsymbol{x}) = 1 - \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})\), respectively. This is summarized as

\[\pi_{\boldsymbol{\theta}}(y \vert \boldsymbol{x}) = \mathrm{Bernoulli}(y \vert \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})) = \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x})^y \, (1 - \mathrm{NN}_{\boldsymbol{\theta}}(\boldsymbol{x}))^{1-y}, \quad y \in \{0, 1\}.\]

Training a classifier, binary or multiclass, requires a set of data that is denoted as \(\mathcal{D} = \{(\boldsymbol{x}_i, y_i)\}_{i=1}^N\). A common approach is to maximize the corresponding likelihood function \(\pi_{\boldsymbol{\theta}}(\mathcal{D}) = \prod_{i=1}^N \pi_{\boldsymbol{\theta}}(y_i \vert \boldsymbol{x}_i)\).
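
As a small illustration of this likelihood-based training, the following sketch (a minimal example assuming a PyTorch implementation, which may differ from the setup of the MNIST demo later on) minimizes the summed negative log-likelihood of the categorical model. The final softmax from above is folded into the cross-entropy loss.

```python
import torch
import torch.nn as nn

# Hypothetical multiclass classifier NN_theta(x) producing logits for M = 10 classes.
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# Cross-entropy with sum reduction is exactly -sum_i log pi_theta(y_i | x_i);
# it applies the log-softmax internally, so the network outputs raw logits.
nll = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def training_step(x, y):
    """One gradient step on the negative log-likelihood of a mini-batch."""
    optimizer.zero_grad()
    loss = nll(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```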

Another simple statistical model discussed for didactic purposes is (unsupervised) distribution fitting. Given a set of data \(\mathcal{D} = \{\boldsymbol{x}_i\}_{i=1}^N\) and a model \(\boldsymbol{X}_i \sim \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\), the likelihood function is written as \(\pi_{\boldsymbol{\theta}}(\mathcal{D}) = \prod_{i=1}^N \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\). The distribution parameters \(\boldsymbol{\theta}\) are fitted by maximizing the log-likelihood

\[\hat{\boldsymbol{\theta}} = \underset{\boldsymbol{\theta}}{\mathrm{argmax}} \, \log \pi_{\boldsymbol{\theta}}(\mathcal{D}), \quad \log \pi_{\boldsymbol{\theta}}(\mathcal{D}) = \sum_{i=1}^N \log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i).\]
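
The same principle can be illustrated with a minimal distribution-fitting sketch. Here, purely for illustration, the mean and standard deviation of a univariate Gaussian are estimated from toy data by gradient-based maximization of the log-likelihood.

```python
import torch

# Toy data with true mean 3 and standard deviation 2 (illustrative choice).
data = torch.randn(1000) * 2.0 + 3.0

mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)  # parametrize sigma > 0 via exp
optimizer = torch.optim.Adam([mu, log_sigma], lr=0.05)

for _ in range(500):
    optimizer.zero_grad()
    dist = torch.distributions.Normal(mu, log_sigma.exp())
    loss = -dist.log_prob(data).sum()           # negative log-likelihood
    loss.backward()
    optimizer.step()

# mu and log_sigma.exp() now approximate the sample mean and standard deviation.
```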

Now, consider a latent variable model with unobserved variables \(\boldsymbol{z}\) that influence the observed ones \(\boldsymbol{x}\). The joint probability model is often hierarchically factorized as \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{z}_i) = \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)\). A marginalization over the latent variables yields the likelihood for a single data point as

\[\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) = \int \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{z}_i) \, \mathrm{d} \boldsymbol{z}_i = \int \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i) \, \mathrm{d} \boldsymbol{z}_i.\]

Based on \(\pi_{\boldsymbol{\theta}}(\mathcal{D}) = \prod_{i=1}^N \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\) one could learn the distribution parameters \(\boldsymbol{\theta} = (\boldsymbol{\theta}_1, \boldsymbol{\theta}_2)\) with all data. Another goal would be to infer the latent variables \(\boldsymbol{z}_i\) pertaining to each data point \(\boldsymbol{x}_i\) for \(i=1,\ldots,N\). This can be accomplished by constructing \(i\)-specific posterior distributions

\[\pi_{\boldsymbol{\theta}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) = \frac{\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{z}_i)}{\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)} = \frac{\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)}{\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)}.\]

Note that the normalizing factor \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\) here is exactly the marginalized likelihood of the latent variable model from above. It is also called the model evidence, a name that emphasizes its role in model selection.
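
To make the marginalization concrete, the sketch below estimates the evidence of a purely illustrative linear-Gaussian latent variable model by naive Monte Carlo sampling from the prior; the weight matrix, noise level and dimensions are hypothetical choices. As discussed further below, such prior sampling is usually an inefficient way of estimating the evidence.

```python
import torch

# Toy model: z ~ N(0, I_K), x | z ~ N(W z, sigma^2 I_L), with W and sigma fixed.
K, L, sigma = 2, 5, 0.5
W = torch.randn(L, K)
x_i = torch.randn(L)                              # some observed data point

# Naive Monte Carlo: pi(x_i) = E_{z ~ prior}[pi(x_i | z)] averaged over T samples.
T = 10_000
z = torch.randn(T, K)                             # prior samples
cond = torch.distributions.Normal(z @ W.T, sigma)
log_lik = cond.log_prob(x_i).sum(dim=-1)          # log pi(x_i | z^(t))
log_evidence = torch.logsumexp(log_lik, dim=0) - torch.log(torch.tensor(float(T)))
```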

Evidence lower bound

In order to represent the posterior, one uses a family of distributions \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\) with tunable parameters \(\boldsymbol{\phi}\). These parameters can be chosen such that the variational distribution resembles the actual posterior, \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \approx \pi_{\boldsymbol{\theta}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\), in some distributional sense. One thus needs a measure of the (dis)similarity of two probability densities. To that end, one usually considers the Kullback-Leibler (KL) divergence from the true posterior to the variational distribution. This so-called relative entropy can be written and decomposed as

\[D_{\mathrm{KL}} \left( q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}}(\cdot \vert \boldsymbol{x}_i) \right) = \int q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \log \left( \frac{q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)}{\pi_{\boldsymbol{\theta}} (\boldsymbol{z}_i \vert \boldsymbol{x}_i)} \right) \, \mathrm{d} \boldsymbol{z}_i = \log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) - \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i).\]

The last term \(\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\) is called the free energy or the evidence lower bound (ELBO). The latter name owes to the fact that \(\log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) \geq \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\), which is a direct consequence of \(D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}}(\cdot \vert \boldsymbol{x}_i)) \geq 0\). The ELBO is given as

\[\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i) = \int q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \log \left( \frac{\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)}{q_{\boldsymbol{\phi}} (\boldsymbol{z}_i \vert \boldsymbol{x}_i)} \right) \, \mathrm{d} \boldsymbol{z}_i.\]

By construction, \(D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}}(\cdot \vert \boldsymbol{x}_i))\) quantifies the discrepancy between the posterior and its approximation. Equivalently, as can be seen above, it is the difference between the logarithm of the marginalized likelihood and its ELBO.

Instead of tackling \(N\) \(i\)-specific problems separately, one can solve them jointly for \(i=1,\ldots,N\). One then has \(\log \pi_{\boldsymbol{\theta}}(\mathcal{D}) = \sum_{i=1}^N \log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\) with \(\log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) = D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}}(\cdot \vert \boldsymbol{x}_i)) + \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\). Also, the KL divergence is additive for factorized distributions. Hence, the objective becomes

\[(\hat{\boldsymbol{\theta}}, \hat{\boldsymbol{\phi}}) = \underset{\boldsymbol{\theta}, \boldsymbol{\phi}}{\mathrm{argmax}} \, \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathcal{D}), \quad \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\mathcal{D}) = \sum_{i=1}^N \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i).\]

Note that the parametrized variational distribution is conditioned on an individual data point. This is reminiscent of but also somewhat different from classical variational inference [Blei et al., 2018] where posteriors are approximated by \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i) \approx \pi_{\boldsymbol{\theta}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\). Essentially, by estimating \(\boldsymbol{\phi}\) it is here learned how to approximately perform Bayesian inference \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\) for any \(\boldsymbol{x}_i\). This is sometimes called amortized variational inference. An interesting discussion of the scheme can be found in [Zhang et al., 2019].

The maximization of the ELBO with respect to the variational parameters \(\boldsymbol{\phi}\) improves the inference model \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \approx \pi_{\boldsymbol{\theta}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\) in the sense that it better approximates the posterior distribution. If the ELBO is additionally maximized as a function of the generative parameters \(\boldsymbol{\theta}\), one might argue that the model evidence \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i)\) is approximately maximized. More precisely, a lower bound of it is maximized. That can be understood as an approximate version of model selection, where the generative model \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{z}_i) = \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)\) is improved. Both of those objectives become apparent from simply writing the ELBO as \(\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i) = \log \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) - D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}}(\cdot \vert \boldsymbol{x}_i))\).

Encoder and decoder

It is instructive to decompose the ELBO into a term involving the likelihood function and another term involving the prior distribution. This highlights the fact that Bayesian inference finds a compromise between prior and data. The decomposition is

\[\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i) = \int q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \right) \, \mathrm{d} \boldsymbol{z}_i - D_{\mathrm{KL}} \left( q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}_2}(\cdot) \right).\]

The first term can also be interpreted as a measure of reconstruction quality, while the second one acts as a regularizer. This is reminiscent of a classical autoencoder and motivates the following terminology. The variational distribution \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\) is called the probabilistic encoder, whereas the distribution \(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i)\) is commonly referred to as the probabilistic decoder.

The variational distribution can be parametrized in different ways. A common representation is a Gaussian \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) = \mathcal{N}(\boldsymbol{z}_i \, \vert \, \boldsymbol{\mu}_{\boldsymbol{\phi}}(\boldsymbol{x}_i), \boldsymbol{\Sigma}_{\boldsymbol{\phi}}(\boldsymbol{x}_i))\) whose mean vector \(\boldsymbol{\mu}_{\boldsymbol{\phi}}(\boldsymbol{x}_i)\) and covariance matrix \(\boldsymbol{\Sigma}_{\boldsymbol{\phi}}(\boldsymbol{x}_i)\) are predicted by a neural network, for example

\[q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) = \mathcal{N} \left( \boldsymbol{z}_i \, \vert \, \boldsymbol{\mu}_{\boldsymbol{\phi}}(\boldsymbol{x}_i), \mathrm{diag} \left( \boldsymbol{\sigma}_{\boldsymbol{\phi}}^{\odot 2}(\boldsymbol{x}_i) \right) \right) = \prod_{j=1}^K \mathcal{N} \left( z_{i,j} \, \vert \, \mu_{\boldsymbol{\phi},j}(\boldsymbol{x}_i), \sigma_{\boldsymbol{\phi},j}^2(\boldsymbol{x}_i) \right).\]

Here, the covariance matrix has the diagonal structure \(\boldsymbol{\Sigma}_{\boldsymbol{\phi}}(\boldsymbol{x}_i) = \mathrm{diag}(\sigma_{\boldsymbol{\phi},1}^2(\boldsymbol{x}_i), \ldots, \sigma_{\boldsymbol{\phi},K}^2(\boldsymbol{x}_i))\). Of course, one can construct more complex non-diagonal representations of the covariance.
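
As a concrete sketch, assuming a PyTorch implementation with flattened MNIST-sized inputs and a (hypothetical) two-dimensional latent space, such a diagonal-Gaussian encoder could look as follows. Predicting the log-variance instead of the standard deviation itself keeps \(\sigma_{\boldsymbol{\phi},j}(\boldsymbol{x}_i)\) positive without any constraint.

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """Sketch of q_phi(z | x) = N(z | mu_phi(x), diag(sigma_phi(x)^2))."""

    def __init__(self, x_dim=784, hidden_dim=256, z_dim=2):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, z_dim)       # mean vector
        self.log_var = nn.Linear(hidden_dim, z_dim)  # log of the diagonal variances

    def forward(self, x):
        h = self.shared(x)
        return self.mu(h), self.log_var(h)
```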

A deep latent variable model often constitutes the generative part. Take, for instance, a deep latent Gaussian model with a Gaussian prior sampling distribution \(\pi(\boldsymbol{z}_i) = \mathcal{N}(\boldsymbol{z}_i \vert \boldsymbol{0}, \boldsymbol{I})\) and another Gaussian \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) = \mathcal{N}(\boldsymbol{x}_i \, \vert \, \boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{z}_i), \boldsymbol{\Sigma}_{\boldsymbol{\theta}}(\boldsymbol{z}_i))\) as the conditional distribution. An example is

\[\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) = \mathcal{N} \left( \boldsymbol{x}_i \, \vert \, \boldsymbol{\mu}_{\boldsymbol{\theta}}(\boldsymbol{z}_i), \mathrm{diag} \left( \boldsymbol{\sigma}_{\boldsymbol{\theta}}^{\odot 2}(\boldsymbol{z}_i) \right) \right) = \prod_{j=1}^L \mathcal{N} \left( x_{i,j} \, \vert \, \mu_{\boldsymbol{\theta},j}(\boldsymbol{z}_i), \sigma_{\boldsymbol{\theta},j}^2(\boldsymbol{z}_i) \right).\]

The means and standard deviations are again computed by a neural network. Instead of the representation of the covariance matrix \(\boldsymbol{\Sigma}_{\boldsymbol{\theta}}(\boldsymbol{z}_i) = \mathrm{diag}(\sigma_{\boldsymbol{\theta},1}^2(\boldsymbol{z}_i), \ldots, \sigma_{\boldsymbol{\theta},L}^2(\boldsymbol{z}_i))\), one can also choose a simpler model \(\boldsymbol{\Sigma} = \sigma^2 \boldsymbol{I}\) with a constant (fixed or learnable) standard deviation.
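
A matching decoder sketch, under the same assumptions and with an equally hypothetical architecture, mirrors the encoder but maps from the latent space back to the data space.

```python
import torch
import torch.nn as nn

class GaussianDecoder(nn.Module):
    """Sketch of pi_theta(x | z) = N(x | mu_theta(z), diag(sigma_theta(z)^2))."""

    def __init__(self, z_dim=2, hidden_dim=256, x_dim=784):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(z_dim, hidden_dim), nn.ReLU())
        self.mu = nn.Linear(hidden_dim, x_dim)
        self.log_var = nn.Linear(hidden_dim, x_dim)  # drop this head for Sigma = sigma^2 I

    def forward(self, z):
        h = self.shared(z)
        return self.mu(h), self.log_var(h)
```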

Similarly, a deep latent Bernoulli model can be used for binary data. Here, a neural network predicts the probability of a positive outcome. The conditional model can then be written as

\[\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) = \prod_{j=1}^L \mathrm{Bernoulli} \left( x_{i,j} \, \vert \, p_{\boldsymbol{\theta},j}(\boldsymbol{z}_i) \right).\]

It is remarked that, strictly speaking, likelihoods based on the Bernoulli distribution should only be used for \(\{0,1\}\)-valued data [Loaiza-Ganem and Cunningham, 2019, Yong et al., 2020]. Despite that, they are sometimes used for nearly binary data sets such as MNIST with continuous pixel intensities in \([0,1]\).
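
For such a Bernoulli decoder, a sketch under the same assumptions only has to output success probabilities through a final sigmoid; the reconstruction term of the ELBO then becomes a negative binary cross-entropy.

```python
import torch
import torch.nn as nn

class BernoulliDecoder(nn.Module):
    """Sketch of pi_theta(x | z) = prod_j Bernoulli(x_j | p_theta_j(z))."""

    def __init__(self, z_dim=2, hidden_dim=256, x_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, x_dim),
            nn.Sigmoid(),                 # success probabilities in (0, 1)
        )

    def forward(self, z):
        return self.net(z)
```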

All in all, we have realized that the VAE actually solves several related problems. As for the generative modeling part, the parameters \(\boldsymbol{\theta}\) of a generative model \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{z}_i) = \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)\) are inferred. At the same time, the parameters \(\boldsymbol{\phi}\) of an inference model \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \approx \pi_{\boldsymbol{\theta}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i)\) approximating the posterior distribution are estimated. The inference model is thus learned jointly with the generative model. In addition, an autoencoder-like latent representation or code \(\boldsymbol{z}_i\) of the original data \(\boldsymbol{x}_i\) has emerged.

Posterior collapse

A number of phenomena where the posterior \(q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \approx \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i)\) merely approximates the prior, at least for some dimensions of \(\boldsymbol{z}_i\), are sometimes referred to as posterior collapse. Several modes of such a collapse in VAE architectures have been discussed in the literature [Chen et al., 2017, Mattei and Frellsen, 2018, Lucas et al., 2019, Dai et al., 2020]. They can be understood through an investigation of the “balance” between the likelihood and KL terms in \(\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i) = \int q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i)) \, \mathrm{d} \boldsymbol{z}_i - D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}_2}(\cdot))\).

Some forms of posterior collapse are related to the choice of the covariance model of the decoder \(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) = \mathcal{N}(\boldsymbol{x}_i \, \vert \, \boldsymbol{\mu}_{\boldsymbol{\theta}_1}(\boldsymbol{z}_i), \boldsymbol{\Sigma}_{\boldsymbol{\theta}_1}(\boldsymbol{z}_i))\). A simple representation of the covariance matrix as \(\boldsymbol{\Sigma} = \sigma^2 \boldsymbol{I}\) with a constant and too large value of the variance \(\sigma^2\) would result in a dominance of the KL term \(D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}_2}(\cdot))\) and eventually encourage the posterior to equal the prior.

For expressive models that represent the covariance as a function \(\boldsymbol{\Sigma}_{\boldsymbol{\theta}}(\boldsymbol{z}_i) = \mathrm{diag}(\sigma_{\boldsymbol{\theta},1}^2(\boldsymbol{z}_i), \ldots, \sigma_{\boldsymbol{\theta},L}^2(\boldsymbol{z}_i))\), whenever \(x_{i,j} \approx \mu_{\boldsymbol{\theta},j}(\boldsymbol{z}_i)\), the likelihood can be made arbitrarily large by decreasing the variance \(\sigma_{\boldsymbol{\theta},j}^2(\boldsymbol{z}_i)\). As a consequence of this collapse of the likelihood around a data point, the model does not necessarily learn anything useful beyond that point.

Another reported form of posterior collapse might happen in case that powerful decoders \(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i)\) are used. They typically do not have the simple Gaussian structure discussed above, but an autoregressive form of dependencies for example. If such a representation is expressive enough to assign a high value to the data point \(\boldsymbol{x}_i\) without utilizing the latent variables \(\boldsymbol{z}_i\), the posterior collapses to the prior as per Bayes’ law.

Stochastic optimization

For a practical solution of the stochastic optimization problem that arises, the ELBO \(\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\) or its gradient \(\nabla_{\boldsymbol{\theta}, \boldsymbol{\phi}} \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\) with respect to the parameters of the probabilistic decoder and encoder needs to be evaluated. To that end, one may start from one of the following two equivalent ways of writing the optimization objective as an expected value

\[\mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i) = \mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)} \left[ \log \left( \pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i, \boldsymbol{Z}_i) \right) - \log \left( q_{\boldsymbol{\phi}} (\boldsymbol{Z}_i \vert \boldsymbol{x}_i) \right) \right] = \mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)} \left[ \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i) \right) \right] - D_{\mathrm{KL}} \left( q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}_2}(\cdot) \right).\]

The second way is beneficial if the KL term \(D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) \, \| \, \pi_{\boldsymbol{\theta}_2}(\cdot))\) can be evaluated analytically, which happens for instance if both distributions are Gaussian. In the simplest case that \(\pi_{\boldsymbol{\theta}_2}(\cdot) = \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\) and \(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) = \mathcal{N}((\mu_{\boldsymbol{\phi},1}(\boldsymbol{x}_i),\ldots,\mu_{\boldsymbol{\phi},K}(\boldsymbol{x}_i)), \mathrm{diag}(\sigma_{\boldsymbol{\phi},1}^2(\boldsymbol{x}_i),\ldots,\sigma_{\boldsymbol{\phi},K}^2(\boldsymbol{x}_i)))\), one has

\[D_{\mathrm{KL}} \left( \mathcal{N}((\mu_1, \ldots, \mu_K), \mathrm{diag}(\sigma_1^2, \ldots, \sigma_K^2)) \, \| \, \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}) \right) = \frac{1}{2} \sum_{j=1}^K \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right).\]
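
This closed-form expression translates directly into code. The helper below (a hypothetical name) assumes that the encoder predicts the log-variances \(\log \sigma_{\boldsymbol{\phi},j}^2(\boldsymbol{x}_i)\), as in the encoder sketch above.

```python
import torch

def gaussian_standard_normal_kl(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over the K latent dimensions."""
    return 0.5 * torch.sum(mu.pow(2) + log_var.exp() - log_var - 1.0, dim=-1)
```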

The likelihood term \(\mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)}[\log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i))]\) turns out to be more involved. It has to be simulated via Monte Carlo sampling in most practical situations. Given samples \(\boldsymbol{z}_i^{(t)} \sim q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)\) for \(t=1,\ldots,T\) from the variational distribution, the likelihood term can be estimated as

\[\mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)} \left[ \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i) \right) \right] \approx \frac{1}{T} \sum_{t=1}^T \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i^{(t)}) \right).\]

Due to their wide applicability, such sampling-based approaches are also called black box variational inference [Ranganath et al., 2014]. A special case of such a scheme uses only a single sample \(T=1\) in order to estimate the likelihood term as \(\mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)}[\log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i))] \approx \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i^{(1)}))\).
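
A corresponding sketch of this Monte Carlo estimator, assuming a Gaussian decoder as sketched earlier and \(T\) latent samples that have already been drawn from the variational distribution, could read as follows.

```python
import torch

def reconstruction_term(decoder, x, z_samples):
    """Estimate E_q[log pi_theta_1(x | Z)] from z_samples of shape (T, batch, z_dim)."""
    log_liks = []
    for z in z_samples:                                      # loop over the T samples
        mu, log_var = decoder(z)
        dist = torch.distributions.Normal(mu, (0.5 * log_var).exp())
        log_liks.append(dist.log_prob(x).sum(dim=-1))        # sum over data dimensions
    return torch.stack(log_liks).mean(dim=0)                 # average over T (often T=1)
```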

At this point, the sampling efficiency of the Monte Carlo approach needs to be discussed briefly. Since the variational distribution \(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)\) used for estimating the term \(\int q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i)) \, \mathrm{d} \boldsymbol{z}_i\) is expected to approximate the posterior increasingly well during training, sampling from it produces high-likelihood samples, i.e. latent values that have likely generated the data point under the generative model. This is significantly more efficient than sampling from the prior distribution \(\pi_{\boldsymbol{\theta}_2}(\cdot)\), which typically yields low-likelihood samples, in order to estimate the marginal likelihood \(\pi_{\boldsymbol{\theta}}(\boldsymbol{x}_i) = \int \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i) \, \pi_{\boldsymbol{\theta}_2}(\boldsymbol{z}_i) \, \mathrm{d} \boldsymbol{z}_i\).

While the gradient \(\nabla_{\boldsymbol{\theta}} \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\) can be estimated by using a sampling strategy, it does not work so easily for \(\nabla_{\boldsymbol{\phi}} \mathcal{F}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}_i)\). The reason is that one can readily use \(\nabla_{\boldsymbol{\theta}_1} \mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)}[\log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i))] = \mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)}[\nabla_{\boldsymbol{\theta}_1} \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i))] \approx T^{-1} \sum_t \nabla_{\boldsymbol{\theta}_1} \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i^{(t)}))\), whereas it is not immediately clear how one should proceed with \(\nabla_{\boldsymbol{\phi}} \mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)}[\log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i))] = \int \log(\pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{z}_i)) \, \nabla_{\boldsymbol{\phi}} q_{\boldsymbol{\phi}}(\boldsymbol{z}_i \vert \boldsymbol{x}_i) \, \mathrm{d} \boldsymbol{z}_i\). Here, the gradient occurs with respect to a distributional parameter of the expectation. The latter integral does not have the mathematical form of an expectation.

A review of noisy gradient estimation for stochastic optimization is provided in [Mohamed et al., 2020]. The so-called score function estimator is unbiased but typically has a high variance. A better alternative is often the reparametrization trick. Here, one introduces an auxiliary random variable \(\boldsymbol{\mathcal{E}}_i \sim p\) following a simple distribution and a random variable transformation with \(\boldsymbol{Z}_i = g_{\boldsymbol{\phi}}(\boldsymbol{\mathcal{E}}_i, \boldsymbol{x}_i) \sim q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)\). One can then write

\[\mathbb{E}_{q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i)} \left[ \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert \boldsymbol{Z}_i) \right) \right] = \mathbb{E}_p \left[ \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert g_{\boldsymbol{\phi}}(\boldsymbol{\mathcal{E}}_i, \boldsymbol{x}_i)) \right) \right].\]

This follows from the law of the unconscious statistician. Now the expectation is taken with respect to a fixed base distribution. It therefore commutes with the gradient operator such that one can approximate the required gradient through Monte Carlo sampling

\[\nabla_{\boldsymbol{\theta}_1, \boldsymbol{\phi}} \mathbb{E}_p \left[ \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert g_{\boldsymbol{\phi}}(\boldsymbol{\mathcal{E}}_i, \boldsymbol{x}_i)) \right) \right] = \mathbb{E}_p \left[ \nabla_{\boldsymbol{\theta}_1, \boldsymbol{\phi}} \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert g_{\boldsymbol{\phi}}(\boldsymbol{\mathcal{E}}_i, \boldsymbol{x}_i)) \right) \right] \approx \frac{1}{T} \sum_{t=1}^T \nabla_{\boldsymbol{\theta}_1, \boldsymbol{\phi}} \log \left( \pi_{\boldsymbol{\theta}_1}(\boldsymbol{x}_i \vert g_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}_i^{(t)}, \boldsymbol{x}_i)) \right).\]

The samples \(\boldsymbol{\epsilon}_i^{(t)} \sim p\) are drawn from the base distribution for \(t=1,\ldots,T\). In case that the base is \(p = \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\) and the variational distribution is \(q_{\boldsymbol{\phi}}(\cdot \vert \boldsymbol{x}_i) = \mathcal{N}(\boldsymbol{\mu}_{\boldsymbol{\phi}}(\boldsymbol{x}_i), \mathrm{diag}(\boldsymbol{\sigma}_{\boldsymbol{\phi}}^{\odot 2}(\boldsymbol{x}_i)))\), the reparametrization is appealingly simple. It is just given as

\[\boldsymbol{z}_i^{(t)} = g_{\boldsymbol{\phi}}(\boldsymbol{\epsilon}_i^{(t)}, \boldsymbol{x}_i) = \boldsymbol{\mu}_{\boldsymbol{\phi}}(\boldsymbol{x}_i) + \boldsymbol{\sigma}_{\boldsymbol{\phi}}(\boldsymbol{x}_i) \odot \boldsymbol{\epsilon}_i^{(t)}.\]
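
In code, this reparametrized sampling is a one-liner. The helper below again assumes that the encoder outputs the log-variance.

```python
import torch

def reparametrize(mu, log_var):
    """Draw z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and log_var."""
    eps = torch.randn_like(mu)
    return mu + (0.5 * log_var).exp() * eps
```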

Eventually, one has an estimator of the ELBO gradient based on Monte Carlo simulation at hand. This noisy gradient can be used in conjunction with all available backpropagation-based optimization schemes to maximize the ELBO.
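
Putting the pieces together, a single training step on the negative ELBO could look as follows. This is a minimal sketch that reuses the hypothetical modules and helpers from above, employs a Bernoulli decoder together with a single-sample (\(T=1\)) estimate of the reconstruction term, and does not necessarily coincide with the setup of the MNIST demonstration.

```python
import torch
import torch.nn.functional as F

# Assumes the GaussianEncoder, BernoulliDecoder, gaussian_standard_normal_kl
# and reparametrize sketches from above, and flattened inputs x in [0, 1].
encoder, decoder = GaussianEncoder(), BernoulliDecoder()
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

def vae_training_step(x):
    """One stochastic gradient step on the negative ELBO of a mini-batch."""
    optimizer.zero_grad()
    mu, log_var = encoder(x)
    z = reparametrize(mu, log_var)                           # single latent sample (T=1)
    p = decoder(z)
    recon = -F.binary_cross_entropy(p, x, reduction="none").sum(dim=-1)  # log-likelihood
    kl = gaussian_standard_normal_kl(mu, log_var)
    loss = (kl - recon).mean()                               # average negative ELBO
    loss.backward()
    optimizer.step()
    return loss.item()
```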