This is a deep dive into the variational auto-encoder (VAE), a machine learning technique.

General Theory

Model

The digit images $x$ are generated by an unknown process $p_\theta(x|z)$, where $z$ is an unobservable latent variable. Plugging a specific $z$ into this function specifies the intensity distribution over the pixels of the image $x$.

We don't know the true parameter $\theta^*$, but we do know the general form of the parametric function $p_\theta(x|z)$. Note that $\theta$ denotes the parameter variable, while $\theta^*$ denotes its true value. The latent variable $z$ follows the distribution $p_\theta(z)$.

The likelihood function $p_\theta(x)$ measures how well $\theta$ describes the observed digit $x$. The likelihood of the true parameter, $p_{\theta^*}(x)$, is maximal because $\theta^*$ is what generated the observation. However, $p_\theta(x)$ is intractable, meaning that we cannot evaluate or differentiate it for every $\theta$ and $x$. The same holds for $p_\theta(x|z)$ and $p_\theta(z|x)$. Since $p_\theta(z|x)$ is hard to evaluate, we introduce a new function $q_\phi(z|x)$ to approximate it. The following diagram sums it all up:

[Diagram: the latent variable $z$ is drawn from the prior distribution $p(z)$; a hidden generative process turns $z$ into the observed image $x$; $q(z|x)$ approximates the reverse mapping from $x$ back to $z$.]
In summary, here is what each variable means.

Variable           Description
$x$                image of a digit
$z$                latent variable that generates the image
$\theta$           parameter of the model $p_\theta$
$\theta^*$         true parameter; $p_{\theta^*}(x)$ is maximal
$p_\theta(x)$      likelihood function: how well $\theta$ describes $x$
$p_\theta(z)$      prior distribution of $z$
$p_\theta(x|z)$    likelihood, or generative function, of $x$ given $z$
$p_\theta(z|x)$    posterior distribution of $z$ given $x$
$q_\phi(z|x)$      approximation to $p_\theta(z|x)$

Objective function

In the variational auto-encoder, the objective function is the likelihood function $p_\theta(x)$ of the true model. Ideally we want to find the optimal $\theta$ that best matches the observed images $x$. $p_\theta(x)$ can be expressed in terms of the Kullback-Leibler divergence $D_{KL}$, where $D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$ measures how well $q_\phi(z|x)$ approximates $p_\theta(z|x)$:

$$D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x)) = \mathbb{E}_{q_\phi(z|x)}\left[\ln q_\phi(z|x) - \ln p_\theta(x,z) + \ln p_\theta(x)\right]$$

or

$$\ln p_\theta(x) = D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x)) + \mathcal{L}$$

where

$$\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}\left[-\ln q_\phi(z|x) + \ln p_\theta(x,z)\right]$$

$\ln p_\theta(x)$ can be taken out of the expectation because it does not depend on $z$. Since $D_{KL}$ is non-negative, $\mathcal{L}$ serves as a lower bound to $\ln p_\theta(x)$. In other words:

$$\ln p_\theta(x) \geq \mathcal{L}$$

The trick of the VAE is to maximize $\mathcal{L}$ instead of $\ln p_\theta(x)$, because $\mathcal{L}$ can actually be computed for many problems.

But since $p_\theta(x,z)$ is hard to calculate, it is useful to rewrite $\mathcal{L}$ as

$$
\begin{aligned}
\mathcal{L} &= \mathbb{E}_{q_\phi(z|x)}\left[-\ln q_\phi(z|x) + \ln p_\theta(x,z)\right] \\
            &= \mathbb{E}_{q_\phi(z|x)}\left[-\ln q_\phi(z|x) + \ln p_\theta(z) + \ln p_\theta(x|z)\right] \\
            &= -D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z)) + \mathbb{E}_{q_\phi(z|x)}\left[\ln p_\theta(x|z)\right]
\end{aligned}
$$
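To sanity-check this decomposition, here is a minimal numpy sketch (my own illustration, not part of the original derivation) using a toy binary latent variable, so every term can be computed exactly; the values in `p_z`, `p_x_given_z` and `q_z_given_x` are made up purely for demonstration:

```python
import numpy as np

# Toy setup: a binary latent variable z in {0, 1} and one fixed observation x.
p_z = np.array([0.4, 0.6])          # prior p(z) (made-up numbers)
p_x_given_z = np.array([0.2, 0.7])  # likelihood p(x|z) for the observed x
q_z_given_x = np.array([0.3, 0.7])  # approximate posterior q(z|x)

p_xz = p_z * p_x_given_z            # joint p(x, z)
p_x = p_xz.sum()                    # evidence p(x) = sum_z p(x, z)
p_z_given_x = p_xz / p_x            # true posterior p(z|x)

# KL(q(z|x) || p(z|x)) and the lower bound L = E_q[ -ln q(z|x) + ln p(x, z) ]
kl = np.sum(q_z_given_x * (np.log(q_z_given_x) - np.log(p_z_given_x)))
lower_bound = np.sum(q_z_given_x * (-np.log(q_z_given_x) + np.log(p_xz)))

print(np.log(p_x), kl + lower_bound, lower_bound)
```

The first two printed values agree (up to floating-point error), confirming $\ln p_\theta(x) = D_{KL} + \mathcal{L}$, and the last one never exceeds $\ln p_\theta(x)$ because $D_{KL} \geq 0$.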

Model for MNIST dataset

To train the model, we will maximize the lower bound

$$\mathcal{L} = -D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z)) + \mathbb{E}_{q_\phi(z|x)}\left[\ln p_\theta(x|z)\right]$$

For this problem, we choose the prior distribution $p_\theta(z)$ of the latent variable $z$ to be the standard normal distribution, which has zero mean and unit variance, i.e. $\mathcal{N}(0,1)$. Why can we do that? Because we don't know the true distribution and may as well work with an easy one! We will also see in a moment that, for the purpose it serves, the exact choice doesn't really matter.

OK, that takes care of $p_\theta(z)$. How about $q_\phi(z|x)$? To be consistent, we also model it with a normal distribution, but one that can have a non-zero mean and non-unit variance. Mathematically,

$$q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)$$

How can it be different from the prior? The idea is that we have 10 digits to encode. Each one has a distribution that deviates from zero, and collectively they form 10 distinct clusters. But if we look at $z$ over all digits, it should still follow the standard normal distribution $\mathcal{N}(0,1)$. In a moment, you will see that the prior $p_\theta(z)$ regularizes the learned parameters $\mu$ and $\sigma$, pulling them back toward the standard normal $\mathcal{N}(0,1)$.

Now we can calculate the first term, $D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z))$. Using the identity for the KL divergence between two normal distributions,

$$D_{KL}(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)) = \ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

We get the first term

$$D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z)) = D_{KL}(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)) = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \ln\sigma_j^2 - \sigma_j^2 - \mu_j^2\right)$$

$J$ is the dimension of the latent variable $z$.
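As a quick check on this formula, here is a small numpy sketch (my own illustration; the values of `mu` and `sigma2` are made up) that evaluates the closed-form sum and compares it against the general two-Gaussian identity above with $\mu_2 = 0$, $\sigma_2^2 = 1$:

```python
import numpy as np

def kl_standard_normal(mu, sigma2):
    """D_KL( N(mu, sigma2) || N(0, 1) ), summed over the J latent dimensions."""
    return -0.5 * np.sum(1.0 + np.log(sigma2) - sigma2 - mu**2)

def kl_two_gaussians(mu1, sigma2_1, mu2, sigma2_2):
    """General identity for D_KL( N(mu1, sigma2_1) || N(mu2, sigma2_2) ), summed per dimension."""
    return np.sum(0.5 * np.log(sigma2_2 / sigma2_1)
                  + (sigma2_1 + (mu1 - mu2)**2) / (2.0 * sigma2_2) - 0.5)

mu = np.array([0.5, -1.0])      # example encoder outputs for J = 2 (made-up values)
sigma2 = np.array([0.8, 1.5])

print(kl_standard_normal(mu, sigma2))
print(kl_two_gaussians(mu, sigma2, np.zeros(2), np.ones(2)))  # should print the same value
```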

Let’s build some intuition!

  • $\ln\sigma_j^2 - \sigma_j^2$ is maximized at $\sigma_j^2 = 1$ (plot)
  • $-\mu_j^2$ is maximized when $\mu_j = 0$ (hope this is obvious…)

So the learning prefers $z$ to follow the standard normal distribution as much as possible. In other words, the first term is a regularization term to make sure that the learned distribution of $z$ is not too crazy.
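To make this concrete (in place of the missing plot), here is a tiny numeric scan (my own illustration) of the per-dimension KL contribution $-\frac{1}{2}(1 + \ln\sigma_j^2 - \sigma_j^2 - \mu_j^2)$; the $(\mu, \sigma^2)$ values are arbitrary:

```python
import numpy as np

def kl_per_dim(mu, sigma2):
    # Per-dimension contribution to D_KL( N(mu, sigma2) || N(0, 1) )
    return -0.5 * (1.0 + np.log(sigma2) - sigma2 - mu**2)

for mu, sigma2 in [(0.0, 1.0), (0.0, 0.5), (0.0, 2.0), (1.0, 1.0), (2.0, 1.0)]:
    print(f"mu={mu:4.1f}  sigma2={sigma2:4.1f}  KL={kl_per_dim(mu, sigma2):.3f}")
```

The contribution is zero exactly at $\mu = 0$, $\sigma^2 = 1$ and grows as either parameter moves away, which is the regularizing pull toward the prior.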

OK, let's understand the second term $\mathbb{E}_{q_\phi(z|x)}[\ln p_\theta(x|z)]$. For binary images (we will use binarized MNIST), $p_\theta(x|z)$ is a Bernoulli distribution. For a single pixel, it is

$$p_\theta(x|z) = x^y(1-x)^{(1-y)}$$

$y$ is the observed pixel intensity and can only be 0 or 1 (binary). Basically, this measures how well the random variable $x$ matches the observed intensity. Suppose it is a dark pixel ($y=0$) and the model predicts $x=0.1$: the model is doing pretty well and scores $0.9$ (with $1.0$ being the full mark). Note that you don't see $z$ on the right-hand side of the equation; it enters implicitly through $x$, because $x$ is generated from $z$. For an image with $L$ pixels,

$$p_\theta(x|z) = \prod_{i=1}^{L} x_i^{y_i}(1-x_i)^{(1-y_i)}$$
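In code, one works with the logarithm of this product, which becomes a sum over pixels (the negative binary cross-entropy). A minimal numpy sketch (my own illustration; the pixel values are made up):

```python
import numpy as np

def bernoulli_log_likelihood(x, y, eps=1e-7):
    """ln p(x|z) = sum_i [ y_i ln x_i + (1 - y_i) ln(1 - x_i) ].

    x: predicted pixel intensities in (0, 1), produced from a latent sample z.
    y: observed binary pixels (0 or 1).
    """
    x = np.clip(x, eps, 1.0 - eps)  # avoid log(0)
    return np.sum(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))

y = np.array([0, 1, 1, 0])             # observed binary pixels (toy example)
x = np.array([0.1, 0.9, 0.8, 0.2])     # model predictions for the same pixels
print(bernoulli_log_likelihood(x, y))  # ln(0.9) + ln(0.9) + ln(0.8) + ln(0.8)
```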

How do we evaluate the expectation $\mathbb{E}_{q_\phi(z|x)}$? Since we can draw $z$ from the approximation $q_\phi(z|x)$, we can estimate it with a simple Monte Carlo average over a few samples of $z$, or just use a single sample per evaluation and not worry about it.
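Putting both terms together, the sketch below (my own illustration; `decode` is a hypothetical stand-in for a decoder that maps a latent sample to pixel intensities, and all numbers are made up) estimates the lower bound $\mathcal{L}$ for a single image using one sample of $z$ drawn from $q_\phi(z|x)$:

```python
import numpy as np

rng = np.random.default_rng(0)

def decode(z):
    # Hypothetical decoder: in a real VAE this is a neural network mapping z
    # to pixel intensities in (0, 1); here it is just a stand-in function.
    return 1.0 / (1.0 + np.exp(-(z.sum() + np.linspace(-1, 1, 4))))

def bernoulli_log_likelihood(x, y, eps=1e-7):
    x = np.clip(x, eps, 1.0 - eps)
    return np.sum(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))

# q_phi(z|x) = N(mu, sigma2) for one image (made-up encoder outputs, J = 2).
mu = np.array([0.2, -0.5])
sigma2 = np.array([0.9, 1.2])

# One-sample Monte Carlo estimate of E_q[ ln p(x|z) ].
z = mu + np.sqrt(sigma2) * rng.standard_normal(2)   # draw z ~ N(mu, sigma2)
y = np.array([0, 1, 1, 0])                          # observed binary pixels
expected_loglik = bernoulli_log_likelihood(decode(z), y)

# Lower bound L = -D_KL(q || p(z)) + E_q[ ln p(x|z) ], maximized during training.
kl = -0.5 * np.sum(1.0 + np.log(sigma2) - sigma2 - mu**2)
lower_bound = -kl + expected_loglik
print(lower_bound)
```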

References

Auto-Encoding Variational Bayes - original paper