MNIST Variational Autoencoder

The latent space representation of the MNIST dataset.

The goal of a Variational Autoencoder (VAE) is to create a generative model that we can call to generate samples that mimic the training dataset. VAEs assume that the samples $\mathbf x_i$ from the dataset $\mathbf{D}$ are iid, that is, generated independently from the same probability distribution. Along with this, each sample has a latent feature vector $\mathbf{z}_i$. This implies the existence of a full joint distribution $p(\mathbf{x},\mathbf{z})$. When creating a generative model we are trying to discover the distribution $p(\mathbf x | \mathbf z)$, whereas when creating a discriminative model we are trying to discover the distribution $p(\mathbf{z}|\mathbf{x})$. VAEs are an attempt to utilize this latent structure to make optimization easier.

To start, let us first get our imports as well as our dataset. For the purpose of this tutorial we will be using the MNIST dataset.

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class data(Dataset):
  def __init__(self, X, Y):
    self.X = X
    self.Y = Y
    if len(self.X) != len(self.Y):
      raise Exception("len(X) != len(Y)")

  def __len__(self):
    return len(self.X)

  def __getitem__(self, index):
    _x = self.X[index].unsqueeze(dim=0)
    _y = self.Y[index].unsqueeze(dim=0)

    return _x, _y

# Importing MNIST
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=ToTensor())

bs = 200

# Data Loader
train_loader = DataLoader(mnist_trainset, batch_size=bs, shuffle=True)
test_loader = DataLoader(mnist_testset, batch_size=bs, shuffle=False)

Let's consider what a particular sample from the dataset represents. We can take the image and reshape it into a vector. Let's represent a particular sample by the symbol $\mathbf x_i$. We can represent the probability of that sample by assuming it is conditioned on some feature variable $\mathbf z_i$.
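
As a small illustration of the reshaping step, here is a sketch that flattens a stand-in image into a vector. The random tensors are assumptions used in place of real MNIST samples, so that the snippet runs without downloading anything.

```python
import torch

# Stand-in for one MNIST sample: ToTensor() yields a (1, 28, 28) float tensor in [0, 1].
image = torch.rand(1, 28, 28)

# Flatten every dimension into a single 784-element vector x_i.
x_i = image.flatten()

# A batch has shape (batch, 1, 28, 28); start_dim=1 keeps the batch axis intact.
batch = torch.rand(200, 1, 28, 28)
x_batch = batch.flatten(start_dim=1)

print(x_i.shape)
print(x_batch.shape)
```

The second form, with `start_dim=1`, is the one that applies to whole batches coming out of a data loader.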


I am not interested in learning the distribution $p(\mathbf x|\mathbf \phi)$ itself. Rather, I want to know the conditionals $p(\mathbf x | \mathbf z, \mathbf \phi)$ and $p(\mathbf z| \mathbf x, \mathbf \phi)$, where $\mathbf z$ lives in a “feature” space. This allows us to use the generative nature of $p(\mathbf x | \mathbf z, \mathbf \phi)$: we can choose a particular set of features $\mathbf z$ and then see which $\mathbf x$ probabilistically corresponds to it. In doing so we are asserting that there exists a total joint distribution $p(\mathbf x, \mathbf z| \mathbf \phi)$, where $\mathbf \phi$ represents the parameters of our model. It is related to the distribution of the data through the following marginalization.

$$ p(\mathbf x|\mathbf \phi) = \int p(\mathbf{x},\mathbf{z}|\mathbf \phi) d\mathbf{z} $$

I will not be using the actual image labels in MNIST during training. This keeps the treatment applicable to image generation in general. We only have observed samples from $p(\mathbf x|\mathbf \phi)$. We can then write the objective function using a maximum likelihood treatment as:

$$ \log p(\mathbf x|\mathbf \phi) = \log \bigg[\int p(\mathbf{x},\mathbf{z}|\mathbf \phi) d\mathbf{z}\bigg] $$
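
For intuition, the marginal likelihood can be estimated by brute force in a toy setting. The sketch below assumes a one-dimensional model that is not part of this tutorial, with $p(z) = \mathcal N(0, 1)$ and $p(x|z) = \mathcal N(z, 1)$, so that the exact answer $p(x) = \mathcal N(0, 2)$ is available for comparison. The helper names are hypothetical.

```python
import math
import random

def log_normal_pdf(value, mean, var):
    """Log density of a 1-D Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)

def mc_log_marginal(x, num_samples, rng):
    """Estimate log p(x) = log E_{p(z)}[p(x|z)] by sampling z from the prior."""
    log_terms = []
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)                      # z ~ p(z) = N(0, 1)
        log_terms.append(log_normal_pdf(x, z, 1.0))  # log p(x|z) with p(x|z) = N(z, 1)
    # Numerically stable log of the sample mean of p(x|z).
    m = max(log_terms)
    return m + math.log(sum(math.exp(t - m) for t in log_terms) / num_samples)

rng = random.Random(0)
x = 1.0
estimate = mc_log_marginal(x, 20000, rng)
exact = log_normal_pdf(x, 0.0, 2.0)  # closed form for this toy model only
print(estimate, exact)
```

In one dimension this works well, but for 784-dimensional images almost every prior sample explains the data poorly, so the number of samples needed becomes impractical.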

Sadly, the integral inside the logarithm makes this objective quite difficult to deal with. Therefore, we can instead make an approximate model for this system by using two neural networks. One of these networks will be the decoder, which models the $p(\mathbf x|\mathbf z, \mathbf \phi)$ that we have seen thus far, and the other will be an encoder. I am going to use the function $q(\mathbf z |\mathbf x, \mathbf \theta)$ to represent the encoder, where $\mathbf \theta$ are its network parameters, and I will take it to be a Gaussian whose mean and covariance are the outputs of the encoder network.

$$ \mathbf \mu, \mathbf \Sigma = \text{Encoder}_\theta(\mathbf x) $$

$$ q(\mathbf z |\mathbf x, \mathbf \theta) = \mathcal{N}(\mathbf z | \mathbf \mu, \mathbf \Sigma) $$

Then notice that the objective function can be written as:

$$ \log p(\mathbf x |\mathbf \phi) = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}[\log p(\mathbf x |\mathbf \phi)] $$

Because $\log p(\mathbf x |\mathbf \phi)$ doesn’t depend on $\mathbf z$, we can take it out of the integral introduced by the expectation and let $q(\mathbf z |\mathbf x, \mathbf \theta)$ integrate to 1. This expression allows for more decomposition.

$$ \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf x |\mathbf \phi)\bigg] $$

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log\bigg( \frac{p(\mathbf x,\mathbf z |\mathbf \phi) }{p(\mathbf z|\mathbf \phi,\mathbf x)}\bigg) \bigg] $$

This is because the marginal probability can be written in terms of the joint and a conditional. We can then multiply the top and bottom by the same quantity and separate the logarithm into a sum.

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log\bigg( \frac{p(\mathbf x,\mathbf z|\mathbf \phi)q(\mathbf z |\mathbf x, \mathbf \theta)}{p(\mathbf z|\mathbf \phi, \mathbf x)q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg) \bigg] $$

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log\bigg( \frac{p(\mathbf x,\mathbf z|\mathbf \phi)}{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg) \bigg] + \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log\bigg( \frac{q(\mathbf z |\mathbf x, \mathbf \theta)}{p(\mathbf z|\mathbf \phi, \mathbf x)}\bigg) \bigg] $$

We can recognize the second expectation as the KL divergence between our network model $q(\mathbf z |\mathbf x, \mathbf \theta)$ and the true posterior $p(\mathbf z|\mathbf \phi, \mathbf x)$. The KL divergence is non-negative, therefore we can view the first expectation as an evidence lower bound (ELBO) on the true log-likelihood of the data. If we maximize the lower bound, we push up the probability of the data as well. This totally reframes the problem of inference.

$$ \text{ELBO} = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log\bigg( \frac{p(\mathbf x,\mathbf z)}{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg) \bigg] $$

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf x,\mathbf z) \bigg] - \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log q(\mathbf z |\mathbf x, \mathbf \theta)\bigg] $$

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf x| \mathbf z, \mathbf \phi)p(\mathbf z) \bigg] - \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log q(\mathbf z |\mathbf x, \mathbf \theta)\bigg] $$
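
The decomposition of the log-likelihood into the ELBO plus a KL divergence can be checked numerically. The sketch below assumes a toy one-dimensional model, not the MNIST model of this tutorial, with $p(z) = \mathcal N(0,1)$, $p(x|z) = \mathcal N(z,1)$ and $q(z|x) = \mathcal N(m, s^2)$, where every term has a closed form. The function names are hypothetical.

```python
import math

def log_evidence(x):
    """Exact log p(x); in this toy model p(x) = N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0

def elbo(x, m, s):
    """E_q[log p(x|z)] + E_q[log p(z)] - E_q[log q(z|x)] for q = N(m, s^2)."""
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    expected_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s ** 2)
    entropy = 0.5 * math.log(2 * math.pi * math.e * s ** 2)  # equals -E_q[log q]
    return expected_log_lik + expected_log_prior + entropy

def kl_to_posterior(x, m, s):
    """KL(q || p(z|x)); the true posterior here is N(x/2, 1/2)."""
    post_mean, post_var = x / 2.0, 0.5
    return (0.5 * math.log(post_var / s ** 2)
            + (s ** 2 + (m - post_mean) ** 2) / (2 * post_var) - 0.5)

x, m, s = 1.0, 0.3, 0.8
print(elbo(x, m, s) + kl_to_posterior(x, m, s), log_evidence(x))
```

For any choice of `m` and `s` the two printed values agree, and the ELBO touches the log-likelihood only when `q` equals the true posterior.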

We can then take $p(\mathbf z)$ as a prior distribution over the latent variables. Let us choose a standard normal $\mathcal{N}(\mathbf z | \mathbf \mu = \mathbf 0, \mathbf \Sigma = \mathbf I)$. This lets us write our expression out as follows:

$$ \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf x| \mathbf z, \mathbf \phi)\bigg] + \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf z)\bigg] - \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log q(\mathbf z |\mathbf x, \mathbf \theta)\bigg] $$

$$ = \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log p(\mathbf x| \mathbf z, \mathbf \phi)\bigg] + \mathbb{E}_{q(\mathbf z |\mathbf x, \mathbf \theta)}\bigg[\log( p(\mathbf z)/q(\mathbf z |\mathbf x, \mathbf \theta))\bigg] $$

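The second term is the negative KL divergence between our encoder $q(\mathbf z |\mathbf x, \mathbf \theta)$ and the prior $p(\mathbf z)$. Remember that we asserted $q(\mathbf z |\mathbf x, \mathbf \theta)$ is normal, which gives that part of the expression a simple closed form. We can then assert that the generative distribution $p(\mathbf x| \mathbf z, \mathbf \phi)$ is also Gaussian with a fixed variance, and choose our network to calculate just the mean value. Up to an additive constant and a scale factor, the negative log-likelihood of the training set then becomes the summed squared error, so maximizing the ELBO amounts to minimizing the MSE loss plus the KL term.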

recon_error = nn.MSELoss(reduction='sum')
def variational_loss(p_x,x,μ_qz,log_var_qz,ϵ,z):
    # KL divergence between q(z|x) = N(μ, diag(exp(log_var))) and the standard normal prior p(z)
    KL = -0.5 * torch.sum(1 + log_var_qz - μ_qz.pow(2) - log_var_qz.exp())
    # Negative ELBO (up to constants): reconstruction error plus KL term
    return recon_error(p_x,x) + KL
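
Both terms of this loss can be sanity-checked against `torch.distributions`. The sketch below uses random tensors as stand-ins for encoder and decoder outputs, and `kl_to_standard_normal` is a hypothetical helper name. It compares the closed-form KL term with `kl_divergence`, and confirms that a unit-variance Gaussian negative log-likelihood equals half the summed squared error plus a constant.

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over all elements."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

torch.manual_seed(0)
mu = torch.randn(5, 4)       # stand-in for encoder means
log_var = torch.randn(5, 4)  # stand-in for encoder log-variances

closed_form = kl_to_standard_normal(mu, log_var)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, p).sum()
print(closed_form.item(), reference.item())

# Gaussian likelihood with unit variance: -log p(x|z) = 0.5 * SSE + constant.
x = torch.rand(5, 784)      # stand-in for flattened images
x_hat = torch.rand(5, 784)  # stand-in for decoder means
nll = -Normal(x_hat, torch.ones_like(x_hat)).log_prob(x).sum()
sse = F.mse_loss(x_hat, x, reduction='sum')
const = 0.5 * x.numel() * math.log(2 * math.pi)
print(nll.item(), (0.5 * sse + const).item())
```

Using the summed squared error without the factor of one half, as the loss above does, corresponds to a Gaussian likelihood with variance $1/2$ rather than $1$. That only changes the relative weight of the reconstruction and KL terms.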

Now we need to discuss backpropagation within this network. When we compute our loss function, it has to be taken as the expectation with respect to $q(\mathbf z |\mathbf x, \mathbf \theta)$. This means we must use a Monte Carlo estimate to evaluate the expectation, as the integral is intractable. However, this makes naive backpropagation impossible: we can't backpropagate through a random sampling step. Therefore we reparameterize the randomness into some dummy variable $\epsilon$. Because our distribution is Gaussian, we may simply choose $\epsilon$ as a standard normal and use the network outputs to shift and rescale it, $\mathbf z = \mathbf \mu + \mathbf \sigma \odot \mathbf \epsilon$. As for the network, we will keep it simple and use three layers for the encoder and three for the decoder. See the implementation below.

class VAE(nn.Module):

    def __init__(self):
        super().__init__()  # required before registering submodules on an nn.Module
        # define encoder layers (the last layer outputs 4 means and 4 log-variances)
        self.e_dense1 = nn.Linear(784,64)
        self.e_dense2 = nn.Linear(64, 32)
        self.e_dense3 = nn.Linear(32, 8)
        # define decoder layers
        self.d_dense1 = nn.Linear(4,32)
        self.d_dense2 = nn.Linear(32, 64)
        self.d_dense3 = nn.Linear(64, 784)

    def encode(self, x):
        x = torch.relu(self.e_dense1(x))
        x = torch.relu(self.e_dense2(x))
        x = (self.e_dense3(x))
        return x[:,:4],x[:,4:]
    def sample(self,μ,log_var):
        σ = torch.exp(0.5*log_var)
        ϵ = torch.randn_like(σ) # Sample from epsilon
        return μ + ϵ.mul(σ),ϵ # transform epsilon into z sample
    def decode(self, x):
        x = torch.relu(self.d_dense1(x))
        x = torch.relu(self.d_dense2(x))
        x = torch.sigmoid(self.d_dense3(x))
        return x
    def forward(self,x):
        μ_qz,log_var_qz = self.encode(x)
        z,ϵ = self.sample(μ_qz,log_var_qz)
        p_x = self.decode(z)
        return p_x,μ_qz,log_var_qz,ϵ,z
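
To see what the reparameterization in `sample` buys us, the following sketch contrasts a direct draw from a normal distribution with the reparameterized form. The tensors `mu` and `log_var` are stand-ins for encoder outputs, and the squared-norm loss is an arbitrary choice made only so that the gradients can be written down by hand.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.zeros(4, requires_grad=True)       # stand-in for the encoder mean
log_var = torch.zeros(4, requires_grad=True)  # stand-in for the encoder log-variance
sigma = torch.exp(0.5 * log_var)

# Direct sampling: the draw is cut off from mu and log_var in the autograd graph.
z_direct = Normal(mu, sigma).sample()
print(z_direct.requires_grad)

# Reparameterized sampling: the randomness lives in eps, and z is a
# deterministic, differentiable function of mu and log_var.
eps = torch.randn_like(sigma)
z = mu + eps * sigma
print(z.requires_grad)

loss = z.pow(2).sum()
loss.backward()
print(mu.grad)       # d loss / d mu      = 2 * z
print(log_var.grad)  # d loss / d log_var = z * eps * sigma
```

`torch.distributions.Normal` also offers `rsample()`, which applies the same trick internally.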

Now we can instantiate our model and run it!!

vae = VAE()
optimizer = optim.Adam(vae.parameters())
criterion = variational_loss
for epoch in range(50):  # loop over the dataset multiple times

    total_loss = 0.0
    for i, data in enumerate(train_loader,start = 1):
        inputs = data[0].flatten(start_dim=1)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        p_x,μ_qz,log_var_qz,ϵ,z = vae(inputs)
        loss = criterion(p_x,inputs,μ_qz,log_var_qz,ϵ,z)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 100 == 0:
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(inputs), len(train_loader.dataset),
                (100*(len(inputs) * i) / len(train_loader.dataset)), loss.item() / len(inputs)))
    # Encode, sample, and decode one random training digit as a quick sanity check.
    sample_idx = np.random.randint(len(mnist_trainset))
    μ,log_var = vae.encode(torch.stack((mnist_trainset[sample_idx][0].flatten(),mnist_trainset[sample_idx][0].flatten())))
    z,ϵ = vae.sample(μ,log_var)
    x = vae.decode(z.unsqueeze(dim=1))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, total_loss / len(train_loader.dataset)))
print('Finished Training')

We can visualize the latent space by evaluating the training samples and saving the resulting mean value of $\mathbf z$ for each one.

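# Collect two coordinates of the latent mean for every digit class.
# The latent space has four dimensions; plotting the first two is an assumption.
Zx = [[] for i in range(10)]
Zy = [[] for i in range(10)]
with torch.no_grad():
    for i in range(len(mnist_trainset)):
        image, label = mnist_trainset[i]
        μ,log_var = vae.encode(image.flatten().unsqueeze(dim=0))
        Zx[label].append(μ[0,0].item())
        Zy[label].append(μ[0,1].item())

fig = plt.figure(figsize=(12, 12),dpi=80)
ax = fig.add_subplot()
N = 250  # number of points plotted per digit

C = ['b','g','r','c','m','y','brown','black','orange','pink']
for i in range(10):
    ax.scatter(Zx[i][:N], Zy[i][:N], c=C[i], label=str(i))

ax.legend()
plt.title("Latent Representation of Handwritten Digits")
plt.show()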


We can also use this to see how well it will produce a handwritten digit with a specified label.

for i in range(20):

    sample_idx = np.random.randint(len(mnist_trainset))
    while (mnist_trainset[sample_idx][1]) != 8:
        sample_idx = np.random.randint(len(mnist_trainset))
    μ,log_var = vae.encode(torch.stack((mnist_trainset[sample_idx][0].flatten(),mnist_trainset[sample_idx][0].flatten())))
    z,ϵ = vae.sample(μ,log_var)
    x = vae.decode(z.unsqueeze(dim=1))
    # Show the first decoded sample as a 28x28 image (the plotting calls are assumed).
    plt.imshow(x[0].detach().reshape(28, 28), cmap='gray')
    plt.title("Expected #"+ str(mnist_trainset[sample_idx][1]))
    plt.show()
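
Since the prior over $\mathbf z$ is a standard normal, new digits can also be generated without any input image by sampling $\mathbf z$ from the prior and decoding it. The sketch below is self-contained, so it uses a hypothetical untrained `decoder` with the same layer sizes as the decoder in the text. Its outputs are noise until the weights have been trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim = 4

# Hypothetical stand-in for a trained decoder (same layer sizes as in the text).
decoder = nn.Sequential(
    nn.Linear(latent_dim, 32), nn.ReLU(),
    nn.Linear(32, 64), nn.ReLU(),
    nn.Linear(64, 784), nn.Sigmoid(),
)

with torch.no_grad():
    z = torch.randn(16, latent_dim)          # z ~ p(z) = N(0, I)
    images = decoder(z).reshape(-1, 28, 28)  # one 28x28 image per latent sample

print(images.shape)
```

With the model trained above, the equivalent call would be `vae.decode(z)` in place of `decoder(z)`.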
