In this post, we will cover a complete implementation of a Variational Autoencoder (VAE), trained by optimizing the ELBO objective. This is a summary of the lecture "Probabilistic Deep Learning with TensorFlow 2" from Imperial College London.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from IPython.display import HTML, Image
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['animation.embed_limit'] = 2**128
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)
Tensorflow Version:  2.5.0
Tensorflow Probability Version:  0.13.0
$$
\begin{aligned}
\text{latent variable } z &\sim N(0, I) = p(z) \\
p(x \vert z) &= \text{decoder}(z) \\
x &\sim p(x \vert z)
\end{aligned}
$$
$$
\text{encoder}(x) = q(z \vert x) \approx p(z \vert x)
$$

$$
\begin{aligned}
\log p(x) &\ge \mathbb{E}_{z \sim q(z \vert x)}[-\log q(z \vert x) + \log p(x \vert z)] \quad \leftarrow \text{maximize this lower bound} \\
&= -\mathrm{KL}(q(z \vert x) \,\|\, p(z)) + \mathbb{E}_{z \sim q(z \vert x)}[\log p(x \vert z)] \quad \leftarrow \text{the Evidence Lower Bound (ELBO)}
\end{aligned}
$$
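For reference, this lower bound follows from Jensen's inequality applied to the log marginal likelihood:

$$
\begin{aligned}
\log p(x) &= \log \int p(x \vert z)\, p(z)\, \mathrm{d}z = \log \mathbb{E}_{z \sim q(z \vert x)}\left[\frac{p(x \vert z)\, p(z)}{q(z \vert x)}\right] \\
&\ge \mathbb{E}_{z \sim q(z \vert x)}\left[\log p(x \vert z) + \log p(z) - \log q(z \vert x)\right] \\
&= -\mathrm{KL}(q(z \vert x) \,\|\, p(z)) + \mathbb{E}_{z \sim q(z \vert x)}[\log p(x \vert z)]
\end{aligned}
$$

The KL term penalizes the approximate posterior for straying from the prior, while the expectation term is the reconstruction log-likelihood; both are estimated by the loss functions defined below.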
latent_size = 2
event_shape = (28, 28, 1)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, Reshape

encoder = Sequential([
Conv2D(8, (5, 5), strides=2, activation='tanh', input_shape=event_shape),
    Conv2D(8, (5, 5), strides=2, activation='tanh'),
Flatten(),
Dense(64, activation='tanh'),
Dense(2 * latent_size),
tfpl.DistributionLambda(lambda t: tfd.MultivariateNormalDiag(
loc=t[..., :latent_size], scale_diag=tf.math.exp(t[..., latent_size:]))),
], name='encoder')
encoder(X_train[:16])
The decoder is roughly the encoder in reverse order.
decoder = Sequential([
Dense(64, activation='tanh', input_shape=(latent_size, )),
Dense(128, activation='tanh'),
    Reshape((4, 4, 8)),  # reshape to the 4x4x8 spatial format expected by Conv2DTranspose
Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
Conv2D(1, (3, 3), padding='SAME'),
Flatten(),
tfpl.IndependentBernoulli(event_shape)
], name='decoder')
decoder(tf.random.normal([16, latent_size]))
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
One way to implement the ELBO loss is to compute the KL divergence term analytically.
def loss_fn(X_true, approx_posterior, X_pred, prior_dist):
"""
X_true: batch of data examples
approx_posterior: the output of encoder
X_pred: output of decoder
prior_dist: Prior distribution
"""
return tf.reduce_mean(tfd.kl_divergence(approx_posterior, prior_dist) - X_pred.log_prob(X_true))
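As a quick aside (not part of the lecture), `tfd.kl_divergence` between a diagonal Gaussian and a standard normal prior can be checked against the closed form $\tfrac{1}{2}\sum_i(\sigma_i^2 + \mu_i^2 - 1 - \log\sigma_i^2)$; the values below are made up purely for illustration:

# Check tfd.kl_divergence against the closed form for a diagonal Gaussian vs. a standard normal
mu = tf.constant([0.5, -1.0])
sigma = tf.constant([0.8, 1.2])
q = tfd.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
p = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
kl_closed_form = 0.5 * tf.reduce_sum(sigma**2 + mu**2 - 1. - tf.math.log(sigma**2))
print(tfd.kl_divergence(q, p).numpy(), kl_closed_form.numpy())  # the two values agree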
The other way is to estimate the KL divergence with Monte Carlo sampling instead of computing it analytically.
def loss_fn(X_true, approx_posterior, X_pred, prior_dist):
reconstruction_loss = -X_pred.log_prob(X_true)
approx_posterior_sample = approx_posterior.sample()
kl_approx = (approx_posterior.log_prob(approx_posterior_sample) - prior_dist.log_prob(approx_posterior_sample))
return tf.reduce_mean(kl_approx + reconstruction_loss)
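Because the sample is drawn from $q(z \vert x)$, this single-sample estimate of the KL term is unbiased. A quick illustration (again with made-up distributions, not from the lecture) shows that with many samples the Monte Carlo estimate agrees with the analytic value:

# Monte Carlo estimate of KL(q || p) versus the analytic value (illustrative distributions)
q = tfd.MultivariateNormalDiag(loc=[1.0, 0.0], scale_diag=[0.5, 2.0])
p = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
z = q.sample(10000)
kl_mc = tf.reduce_mean(q.log_prob(z) - p.log_prob(z))
print(kl_mc.numpy(), tfd.kl_divergence(q, p).numpy())  # close, differing only by sampling noise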
@tf.function
def get_loss_and_grads(x):
with tf.GradientTape() as tape:
approx_posterior = encoder(x)
approx_posterior_sample = approx_posterior.sample()
X_pred = decoder(approx_posterior_sample)
current_loss = loss_fn(x, approx_posterior, X_pred, prior)
grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
return current_loss, grads
optimizer = tf.keras.optimizers.Adam()
for epoch in range(num_epochs):
for train_batch in train_data:
loss, grads = get_loss_and_grads(train_batch)
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
z = prior.sample(1) # (1, 2)
x = decoder(z).sample() # (1, 28, 28, 1)
X_encoded = encoder(X_sample)
def vae(inputs):
approx_posterior = encoder(inputs)
decoded = decoder(approx_posterior.sample())
return decoded.sample()
reconstruction = vae(X_sample)
To review the terminology above, let's now build each component — the encoder q(z | x), the decoder p(x | z), and the prior p(z) — and train a full VAE on Fashion MNIST.
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Reshape
# Import Fashion MNIST, make it a Tensorflow Dataset
(X_train, _), (X_test, _) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
example_X = X_test[:16]
batch_size = 64
X_train = tf.data.Dataset.from_tensor_slices(X_train).batch(batch_size)
# Define the encoding distribution, q(z | x)
latent_size = 2
event_shape = (28, 28)
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(2 * latent_size),
tfpl.DistributionLambda(
lambda t: tfd.MultivariateNormalDiag(
loc=t[..., :latent_size],
scale_diag=tf.math.exp(t[..., latent_size:])
)
)
])
WARNING:tensorflow:From /home/chanseok/anaconda3/envs/torch/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py:346: calling MultivariateNormalDiag.__init__ (from tensorflow_probability.python.distributions.mvn_diag) with scale_identity_multiplier is deprecated and will be removed after 2020-01-01. Instructions for updating: `scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.
# Pass an example image through the network - should return a batch of MultivariateNormalDiag
encoder(example_X)
<tfp.distributions.MultivariateNormalDiag 'sequential_distribution_lambda_MultivariateNormalDiag' batch_shape=[16] event_shape=[2] dtype=float32>
# Define the decoding distribution, p(x | z)
decoder = Sequential([
Dense(32, activation='relu'),
Dense(64, activation='relu'),
Dense(128, activation='relu'),
Dense(256, activation='relu'),
Dense(tfpl.IndependentBernoulli.params_size(event_shape)),
tfpl.IndependentBernoulli(event_shape)
])
# Pass a batch of examples to the decoder
decoder(tf.random.normal([16, latent_size]))
<tfp.distributions.Independent 'sequential_1_independent_bernoulli_IndependentBernoulli_Independentsequential_1_independent_bernoulli_IndependentBernoulli_Bernoulli' batch_shape=[16] event_shape=[28, 28] dtype=float32>
# Define the prior, p(z) - a standard bivariate Gaussian
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
The loss function we need to estimate is
$$
-\mathrm{ELBO} = \mathrm{KL}[\, q(z \vert x) \,\|\, p(z) \,] - \mathbb{E}_{Z \sim q(z \vert x)}[\log p(x \vert Z)]
$$
where $x = (x_1, x_2, \ldots, x_n)$ denotes all observations and $z = (z_1, z_2, \ldots, z_n)$ the corresponding latent variables.
Assuming the examples are independent, we can write this as a sum over examples:
$$
\sum_j \mathrm{KL}[\, q(z_j \vert x_j) \,\|\, p(z_j) \,] - \mathbb{E}_{Z_j \sim q(z_j \vert x_j)}[\log p(x_j \vert Z_j)]
$$
# Specify the loss function, an estimate of the -ELBO
def loss(x, encoding_dist, sampled_decoding_dist, prior):
return tf.reduce_sum(
tfd.kl_divergence(encoding_dist, prior) - sampled_decoding_dist.log_prob(x)
)
# Define a function that returns the loss and its gradients
@tf.function
def get_loss_and_grads(x):
with tf.GradientTape() as tape:
encoding_dist = encoder(x)
sampled_z = encoding_dist.sample()
sampled_decoding_dist = decoder(sampled_z)
current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior)
grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
return current_loss, grads
# Compile and train the model
num_epochs = 10
optimizer = tf.keras.optimizers.Adam()
for i in range(num_epochs):
for train_batch in X_train:
current_loss, grads = get_loss_and_grads(train_batch)
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))
-ELBO after epoch 1: 8990
-ELBO after epoch 2: 8858
-ELBO after epoch 3: 8782
-ELBO after epoch 4: 8820
-ELBO after epoch 5: 8716
-ELBO after epoch 6: 8664
-ELBO after epoch 7: 8727
-ELBO after epoch 8: 8667
-ELBO after epoch 9: 8810
-ELBO after epoch 10: 8675
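Note that the printed value is the -ELBO of the final batch in each epoch, so it fluctuates. As an optional variation (a sketch reusing the `get_loss_and_grads` and `optimizer` defined above), the average over a whole epoch can be tracked with `tf.keras.metrics.Mean`:

# Optional sketch: track the average per-batch -ELBO over each epoch
loss_metric = tf.keras.metrics.Mean()
for i in range(num_epochs):
    loss_metric.reset_states()
    for train_batch in X_train:
        current_loss, grads = get_loss_and_grads(train_batch)
        optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
        loss_metric(current_loss)
    print('Average -ELBO after epoch {}: {:.0f}'.format(i + 1, loss_metric.result().numpy()))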
# Connect encoder and decoder, compute a reconstruction
def vae(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.sample()
example_reconstruction = vae(example_X).numpy().squeeze()
# Plot examples against reconstructions
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
Because the decoding distribution is a pixel-wise Bernoulli, sampling from it produces binary pixel values, so sampled reconstructions of grayscale images look noisy. Using the mean of the decoding distribution instead gives much better-looking reconstructions.
# Connect encoder and decoder, compute a reconstruction with mean
def vae_mean(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.mean()
example_reconstruction = vae_mean(example_X).numpy().squeeze()
# Plot examples against reconstructions
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
# Generate an example - sample a z value, then sample a reconstruction from p(x|z)
z = prior.sample(6)
generated_x = decoder(z).sample()
# Display generated_x
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')
# Generate an example - sample a z value, then sample a reconstruction from p(x|z)
z = prior.sample(6)
generated_x = decoder(z).mean()
# Display generated_x
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')
What if we estimate the KL divergence with Monte Carlo sampling instead?
encoder = Sequential([
Flatten(input_shape=event_shape),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(64, activation='relu'),
Dense(32, activation='relu'),
Dense(2 * latent_size),
tfpl.DistributionLambda(
lambda t: tfd.MultivariateNormalDiag(
loc=t[..., :latent_size],
scale_diag=tf.math.exp(t[..., latent_size:])
)
)
])
decoder = Sequential([
Dense(32, activation='relu'),
Dense(64, activation='relu'),
Dense(128, activation='relu'),
Dense(256, activation='relu'),
Dense(tfpl.IndependentBernoulli.params_size(event_shape)),
tfpl.IndependentBernoulli(event_shape)
])
# Define the prior, p(z) - a standard bivariate Gaussian
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
def loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z):
reconstruction_loss = -sampled_decoding_dist.log_prob(x)
kl_approx = (encoding_dist.log_prob(sampled_z) - prior.log_prob(sampled_z))
return tf.reduce_sum(kl_approx + reconstruction_loss)
@tf.function
def get_loss_and_grads(x):
with tf.GradientTape() as tape:
encoding_dist = encoder(x)
sampled_z = encoding_dist.sample()
sampled_decoding_dist = decoder(sampled_z)
current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z)
grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
return current_loss, grads
# Compile and train the model
num_epochs = 10
optimizer = tf.keras.optimizers.Adam()
for i in range(num_epochs):
for train_batch in X_train:
current_loss, grads = get_loss_and_grads(train_batch)
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))
print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))
-ELBO after epoch 1: 8914
-ELBO after epoch 2: 8802
-ELBO after epoch 3: 8799
-ELBO after epoch 4: 8743
-ELBO after epoch 5: 8790
-ELBO after epoch 6: 8716
-ELBO after epoch 7: 8787
-ELBO after epoch 8: 8686
-ELBO after epoch 9: 8650
-ELBO after epoch 10: 8813
# Connect encoder and decoder, compute a reconstruction with mean
def vae_mean(inputs):
approx_posterior = encoder(inputs)
decoding_dist = decoder(approx_posterior.sample())
return decoding_dist.mean()
example_reconstruction = vae_mean(example_X).numpy().squeeze()
# Plot examples against reconstructions
f, axs = plt.subplots(2, 6, figsize=(16, 5))
for j in range(6):
axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')
axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
axs[0, j].axis('off')
axs[1, j].axis('off')
# Generate an example - sample a z value, then sample a reconstruction from p(x|z)
z = prior.sample(6)
generated_x = decoder(z).mean()
# Display generated_x
f, axs = plt.subplots(1, 6, figsize=(16, 5))
for j in range(6):
axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')
axs[j].axis('off')