Deep dive into variational auto-encoder (Part 1)
This is a deep dive into the variational auto-encoder (VAE), a machine learning technique.
General Theory
Model
The digit images (\(x\)) are generated by an unknown process \(p_{\theta^*}(x\vert z)\), where \(z\) is an unobservable latent variable. Plugging a specific \(z\) into this function specifies the intensity distribution of the pixels of the image \(x\).
We don’t know the true parameter \(\theta^*\), but we know the general form of the parametric function \(p_{\theta}(x\vert z)\). Note that \(\theta\) denotes the parameters as a variable, while \(\theta^*\) denotes their true value. The latent variable \(z\) follows the distribution \(p_{\theta^*}(z)\).
The likelihood function \(p_{\theta}(x) = \int p_{\theta}(z)\, p_{\theta}(x\vert z)\, dz\) measures how well \(\theta\) describes the observed digit \(x\). Since \(\theta^*\) generated the observations, we estimate it by finding the \(\theta\) that makes \(p_{\theta}(x)\) as large as possible. However, \(p_{\theta}(x)\) is intractable, meaning that we cannot evaluate or differentiate it for every \(\theta\) and \(x\). The same is true of the posterior \(p_{\theta}(z\vert x) = p_{\theta}(x\vert z)\,p_{\theta}(z)/p_{\theta}(x)\), which depends on \(p_{\theta}(x)\). Since \(p_{\theta}(z\vert x)\) is hard to evaluate, we introduce a new function \(q_{\phi}(z\vert x)\) to approximate it. The following diagram sums it all up:
Diagram: prior distribution \(p(z)\) → latent variable \(z\) → hidden process → observed image \(x\) → approximate posterior \(q(z\vert x)\) → prior \(p(z)\)
In summary, the table below lists what each variable means.
Variable | Description |
---|---|
\(x\) | image of digit |
\(z\) | latent variable that generates the image |
\(\theta\) | Parameters of the generative model \(p_{\theta}\) |
\(\theta^*\) | True parameters that generated the data |
\(p_{\theta}(x)\) | Likelihood function: how well \(\theta\) describes \(x\) |
\(p_{\theta}(z)\) | Prior distribution of \(z\) |
\(p_{\theta}(x\vert z)\) | Likelihood or generative function of \(x\) given \(z\) |
\(p_{\theta}(z\vert x)\) | Posterior distribution of \(z\) given \(x\) (intractable) |
\(q_{\phi}(z\vert x)\) | Approximation to \(p_{\theta}(z\vert x)\), with parameters \(\phi\) |
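To make the generative side of the model concrete, here is a minimal sketch of sampling an image. It assumes, purely for illustration, a 2-dimensional latent \(z\), a 28×28 binary image, and a made-up linear decoder with a sigmoid output standing in for the unknown process \(p_{\theta}(x\vert z)\). The hypothetical weights `W` and `b` play the role of \(\theta\); they are random here, not learned.

```python
import numpy as np

rng = np.random.default_rng(0)

J = 2          # latent dimension (hypothetical choice)
L = 28 * 28    # number of pixels in an MNIST-sized image

# Hypothetical decoder parameters (the role of theta): one linear layer.
W = rng.normal(scale=0.5, size=(L, J))
b = np.zeros(L)

def decode(z):
    """Map a latent vector z to per-pixel 'on' probabilities via a sigmoid."""
    return 1.0 / (1.0 + np.exp(-(W @ z + b)))

def sample_image():
    z = rng.standard_normal(J)                     # z ~ p(z) = N(0, I)
    probs = decode(z)                              # parameters of p_theta(x | z)
    x = (rng.random(L) < probs).astype(np.uint8)   # x ~ Bernoulli(probs), pixel by pixel
    return z, probs, x

z, probs, x = sample_image()
print(z.shape, probs.shape, x.reshape(28, 28).shape)  # prints (2,) (784,) (28, 28)
```

The random decoder only produces noise-like images; learning \(\theta\) is what would make the samples look like digits.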
Objective function
In a variational auto-encoder, the objective function is the log-likelihood of the model, \(\ln p_{\theta}(x)\). Ideally we want to find the \(\theta\) that best matches the observed images \(x\). \(\ln p_{\theta}(x)\) can be expressed in terms of the Kullback-Leibler divergence \(D_{KL}\), where \(D_{KL}(q_{\phi}(z\vert x) \vert\vert p_{\theta}(z\vert x))\) measures how well \(q_{\phi}(z\vert x)\) approximates \(p_{\theta}(z\vert x)\). Writing \(p_{\theta}(z\vert x) = p_{\theta}(x,z)/p_{\theta}(x)\) inside the definition of \(D_{KL}\) gives
\[D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta}(z\vert x)) = \mathbb{E}_{q_{\phi}(z\vert x)}[\ln q_{\phi}(z\vert x)-\ln p_{\theta}(x,z) + \ln p_{\theta}(x)]\]or
\[\ln p_{\theta}(x) = D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta}(z\vert x) ) + \mathcal{L}.\]where
\[\mathcal{L} = \mathbb{E}_{q_{\phi}(z\vert x)}[-\ln q_{\phi}(z\vert x)+\ln p_{\theta}(x,z) ]\]\(\ln p_{\theta}(x)\) can be taken out of the expectation because it does not depend on \(z\). Since $D_{KL}$ is non-negative, \(\mathcal{L}\) serves as a lower bound to \(\ln p_{\theta}(x)\). In other words:
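\[\ln p_{\theta}(x) \ge \mathcal{L}\]The trick of the VAE is to maximize \(\mathcal{L}\), known as the evidence lower bound (ELBO), instead of \(\ln p_{\theta}(x)\), because \(\mathcal{L}\) can be calculated for many problems where \(\ln p_{\theta}(x)\) cannot. Because \(\ln p_{\theta}(x)\) does not depend on \(\phi\), raising \(\mathcal{L}\) with respect to \(\phi\) also lowers \(D_{KL}\), making \(q_{\phi}(z\vert x)\) a better approximation of the posterior.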
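The identity \(\ln p_{\theta}(x) = D_{KL} + \mathcal{L}\) can be checked numerically on a toy model where every quantity is tractable. As an assumption for illustration only (this is not the MNIST model), take \(p(z) = \mathcal{N}(0,1)\) and \(p(x\vert z) = \mathcal{N}(z, 1)\), so that \(p(x) = \mathcal{N}(0, 2)\) and the true posterior is \(p(z\vert x) = \mathcal{N}(x/2, 1/2)\). For any Gaussian \(q(z\vert x) = \mathcal{N}(m, s^2)\), the expectations in \(\mathcal{L}\) have closed forms; the helper names below are hypothetical.

```python
import math

def log_normal_pdf(v, mean, var):
    """ln N(v; mean, var) for a univariate normal."""
    return -0.5 * math.log(2 * math.pi * var) - (v - mean) ** 2 / (2 * var)

def kl_normal(m1, v1, m2, v2):
    """D_KL(N(m1, v1) || N(m2, v2)) for univariate normals (v = variance)."""
    return 0.5 * math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / (2 * v2) - 0.5

def elbo(x, m, s2):
    """Closed-form L = E_q[-ln q(z) + ln p(z) + ln p(x|z)] for the toy model,
    with q(z|x) = N(m, s2), p(z) = N(0, 1) and p(x|z) = N(z, 1)."""
    log_2pi = math.log(2 * math.pi)
    e_log_prior = -0.5 * log_2pi - 0.5 * (m ** 2 + s2)      # E_q[ln p(z)]
    e_log_lik = -0.5 * log_2pi - 0.5 * ((x - m) ** 2 + s2)  # E_q[ln p(x|z)]
    entropy_q = 0.5 * (log_2pi + 1 + math.log(s2))          # -E_q[ln q(z|x)]
    return e_log_prior + e_log_lik + entropy_q

x = 1.3                                 # one observation (arbitrary)
log_px = log_normal_pdf(x, 0.0, 2.0)    # exact ln p(x), since p(x) = N(0, 2)

for m, s2 in [(0.0, 1.0), (0.5, 0.3), (x / 2, 0.5)]:
    kl = kl_normal(m, s2, x / 2, 0.5)   # D_KL(q || true posterior N(x/2, 1/2))
    lb = elbo(x, m, s2)
    print(f"m={m:.2f} s2={s2:.2f}  L={lb:.4f}  KL={kl:.4f}  "
          f"L+KL={lb + kl:.4f}  ln p(x)={log_px:.4f}")
```

Every row should show the same value for L+KL and ln p(x). The last row sets \(q\) equal to the true posterior, so the KL term is zero and the bound is tight.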
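Since the joint distribution \(p_{\theta}(x,z)\) is hard to work with directly, it is useful to factor it as \(p_{\theta}(x,z) = p_{\theta}(z)\,p_{\theta}(x\vert z)\) and rewrite \(\mathcal{L}\) as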
\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{q_{\phi}(z\vert x)}[-\ln q_{\phi}(z\vert x)+\ln p_{\theta}(x,z) ]\\ &= \mathbb{E}_{q_{\phi}(z\vert x)}[-\ln q_{\phi}(z\vert x)+\ln p_{\theta}(z)+\ln p_{\theta}(x\vert z) ]\\ &= -D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta}(z)) + \mathbb{E}_{q_{\phi}(z\vert x)} [\ln p_{\theta}(x\vert z)] \end{aligned}\]
Model for MNIST dataset
To train the model, we will maximize the lower bound
\[\mathcal{L} = -D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta } (z)) + \mathbb{E}_{q_{\phi}(z\vert x)} [\ln p_{\theta}(x|z)]\]For this problem, we choose the prior distribution \(p_{\theta}(z)\) of the latent variable \(z\) to be the standard normal distribution, which has zero mean and unit variance in each dimension, i.e. \(\mathcal{N}(0,1)\). Why can we do that? Because we don’t know the true distribution, we may as well work with an easy one! As we will see shortly, for the regularizing role the prior plays here, the exact choice doesn’t really matter.
OK, that takes care of \(p_{\theta}(z)\). How about \(q_{\phi}(z\vert x)\)? To be consistent, we also model it with a normal distribution, but one whose mean \(\mu\) and variance \(\sigma^2\) are computed from \(x\) (by the encoder, with parameters \(\phi\)), so they can differ from zero and one. Mathematically,
\[q_{\phi}(z\vert x) = \mathcal{N}(\mu, \sigma^2)\]How can it be different from the prior? The idea is that we have 10 digits to encode. The distribution of \(z\) for each digit deviates from zero, and collectively they form 10 distinct clusters. But if we look at \(z\) over all digits, it should still roughly follow the standard normal distribution \(\mathcal{N}(0,1)\). In a moment, you will see that the prior \(p_{\theta}(z)\) regularizes the learned parameters \(\mu\) and \(\sigma\) by pulling them toward the standard normal \(\mathcal{N}(0,1)\).
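Now we can calculate the first term, \(D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta } (z))\). The KL divergence between two univariate normal distributions has the closed form
\[D_{KL}(\mathcal{N}(\mu_1, \sigma_1^2)\vert\vert \mathcal{N}(\mu_2, \sigma_2^2)) = \ln \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2}\]Assuming the dimensions of \(z\) are independent under both \(q_{\phi}(z\vert x)\) and \(p_{\theta}(z)\), the divergence is a sum over dimensions. Applying the identity to each dimension with \(\mu_2 = 0\) and \(\sigma_2 = 1\), we get the first term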
\[\begin{aligned} -D_{KL}(q_{\phi}(z\vert x) \vert \vert p_{\theta } (z)) &= -D_{KL}(\mathcal{N}(\mu, \sigma^2) \vert \vert \mathcal{N}(0, 1) )\\ &= \frac{1}{2} \sum_{j=1}^{J}(1 + \ln \sigma_j^2 -\sigma_j^2 - \mu_j^2) \end{aligned}\]$J$ is the dimension of the latent variable $z$.
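The closed form above is short to code. The sketch below assumes, as is common in VAE implementations but not stated in this post, that the encoder outputs \(\ln \sigma_j^2\) (here the hypothetical `log_var`) rather than \(\sigma_j\), which keeps the variance positive. It checks the summed formula against the general two-normal identity applied dimension by dimension; the `mu` and `log_var` values are made up.

```python
import numpy as np

def kl_two_normals(mu1, var1, mu2, var2):
    """Elementwise D_KL(N(mu1, var1) || N(mu2, var2)) for univariate normals."""
    return 0.5 * np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / (2 * var2) - 0.5

def neg_kl_to_standard_normal(mu, log_var):
    """-D_KL(N(mu, diag(sigma^2)) || N(0, I)) = 1/2 * sum_j (1 + ln sigma_j^2 - sigma_j^2 - mu_j^2)."""
    return 0.5 * np.sum(1.0 + log_var - np.exp(log_var) - mu ** 2)

mu = np.array([0.3, -1.2, 0.0])         # example encoder outputs (made up), J = 3
log_var = np.array([-0.5, 0.2, 0.0])

closed_form = neg_kl_to_standard_normal(mu, log_var)
per_dim = -np.sum(kl_two_normals(mu, np.exp(log_var), 0.0, 1.0))
print(closed_form, per_dim)  # the two numbers agree up to rounding
```

The third dimension (\(\mu_j = 0\), \(\ln\sigma_j^2 = 0\)) contributes exactly zero, matching the standard normal prior.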
Let’s build some intuition!
- $\ln \sigma_j^2 -\sigma_j^2$ is maximized at $\sigma_j^2 = 1$, where its derivative with respect to $\sigma_j^2$, $1/\sigma_j^2 - 1$, is zero
- $-\mu_j^2$ is maximized when $\mu_j = 0$ (hope this is obvious…)
So learning pushes each dimension of $z$ toward the standard normal distribution as much as possible. In other words, the first term is a regularization term that keeps the learned distribution of $z$ from drifting too far from the prior.
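A quick numerical check of the two bullet points: evaluating one dimension's contribution \(\tfrac12(1 + \ln \sigma^2 - \sigma^2 - \mu^2)\) on a grid shows that it peaks at \(\mu = 0\), \(\sigma^2 = 1\), where it equals zero. The grid ranges and the helper name `neg_kl_term` are arbitrary choices for illustration.

```python
import numpy as np

def neg_kl_term(mu, var):
    """One latent dimension's contribution to -D_KL(q || p(z))."""
    return 0.5 * (1.0 + np.log(var) - var - mu ** 2)

var_grid = np.linspace(0.05, 3.0, 60)   # candidate sigma^2 values (step 0.05)
mu_grid = np.linspace(-2.0, 2.0, 41)    # candidate mu values (step 0.1)
MU, VAR = np.meshgrid(mu_grid, var_grid, indexing="ij")
values = neg_kl_term(MU, VAR)

# Locate the grid point with the largest value.
i, j = np.unravel_index(np.argmax(values), values.shape)
print(f"best mu = {mu_grid[i]:.2f}, best sigma^2 = {var_grid[j]:.2f}, value = {values[i, j]:.4f}")
```

Every other grid point gives a negative value, so any deviation from the standard normal is penalized.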
OK, let’s understand the second term \(\mathbb{E}_{q_{\phi}(z\vert x)} [\ln p_{\theta}(x\vert z)]\). For binary images (we will use binarized MNIST), $p_{\theta}(x\vert z)$ is a Bernoulli distribution. For a single pixel, it is
\[p_{\theta}(x|z) =x^y(1-x)^{(1-y)}\]Here, with a slight abuse of notation, $y$ is the observed pixel intensity, which can only be 0 or 1 (binary), and $x$ is the probability, predicted by the model from $z$, that the pixel is 1. Basically, this is a measure of how well the prediction $x$ matches the observed intensity. Suppose it is a dark pixel ($y=0$) and the model predicts $x=0.1$: the model is doing pretty well and scores $0.1^0 \cdot 0.9^1 = 0.9$ (with 1.0 being the full mark). Note that you don’t see $z$ on the right-hand side of the equation because it enters implicitly through $x$, which is generated from $z$. For an image with $L$ pixels,
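\[p_{\theta}(x|z) =\prod\limits_{i=1}^{L} x_i^{y_i}(1-x_i)^{(1-y_i)}\]so its logarithm is a sum over pixels, \(\ln p_{\theta}(x\vert z) = \sum_{i=1}^{L} [y_i \ln x_i + (1-y_i)\ln(1-x_i)]\), which is the negative binary cross-entropy between the observed and predicted pixels. How do we evaluate the expectation \(\mathbb{E}_{q_{\phi}(z\vert x)}\)? It is taken over \(z\) for a given image, so we can estimate it by drawing samples of \(z\) from the approximation \(q_{\phi}(z\vert x)\) and averaging \(\ln p_{\theta}(x\vert z)\) over them. In practice a single sample per image is often enough: the original paper, Auto-Encoding Variational Bayes, reports that one sample per data point works as long as the minibatch is large enough (e.g. 100 images).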
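Putting the two terms together, here is a minimal sketch of a one-sample estimate of \(\mathcal{L}\) for a single binarized image. Everything model-specific is a hypothetical stand-in: random, untrained linear maps play the role of the encoder (parameters \(\phi\), producing \(\mu\) and \(\ln\sigma^2\)) and the decoder (parameters \(\theta\), producing pixel probabilities), and the image is random noise rather than an MNIST digit. Following the notation above, `x_prob` holds the predicted probabilities \(x_i\) and `y` the observed binary pixels \(y_i\).

```python
import numpy as np

rng = np.random.default_rng(42)
J, L = 2, 28 * 28   # latent dimension and pixel count (hypothetical sizes)

# Hypothetical, untrained parameters standing in for phi (encoder) and theta (decoder).
W_mu = rng.normal(scale=0.01, size=(J, L))
W_logvar = rng.normal(scale=0.01, size=(J, L))
W_dec = rng.normal(scale=0.5, size=(L, J))
b_dec = np.zeros(L)

def encode(y):
    """Return the mean and log-variance of q_phi(z | x) for a binary image y."""
    return W_mu @ y, W_logvar @ y

def decode(z):
    """Return per-pixel probabilities x_i = p(pixel i is 1 | z)."""
    return 1.0 / (1.0 + np.exp(-(W_dec @ z + b_dec)))

def elbo_one_sample(y, rng):
    mu, log_var = encode(y)
    # First term: closed-form -D_KL(N(mu, sigma^2) || N(0, I)).
    neg_kl = 0.5 * np.sum(1.0 + log_var - np.exp(log_var) - mu ** 2)
    # Second term: draw one z ~ q(z|x) = N(mu, sigma^2) and evaluate ln p(x|z).
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(J)
    x_prob = np.clip(decode(z), 1e-7, 1 - 1e-7)   # numerical guard against log(0)
    log_lik = np.sum(y * np.log(x_prob) + (1 - y) * np.log(1 - x_prob))
    return neg_kl + log_lik, neg_kl, log_lik

y = (rng.random(L) < 0.2).astype(float)   # a fake binarized image
total, neg_kl, log_lik = elbo_one_sample(y, rng)
print(f"-KL = {neg_kl:.3f}, ln p(x|z) = {log_lik:.3f}, L estimate = {total:.3f}")
```

Both terms are at most zero, so the estimate is never positive; training would adjust the encoder and decoder weights to push it upward. The clipping is only a numerical safeguard and is not part of the math.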
References
D. P. Kingma and M. Welling, Auto-Encoding Variational Bayes (2013), the original paper