# adapted from https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py
%matplotlib inline
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
import keras
from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Lambda
from keras.losses import binary_crossentropy
Using TensorFlow backend.
# from keras.datasets import mnist
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# Load the MNIST arrays from a local .npz archive (same contents as keras.datasets.mnist).
mnist = np.load('data/mnist.npz')
x_train, y_train = mnist['x_train'], mnist['y_train']
x_test, y_test = mnist['x_test'], mnist['y_test']

J = 2    # dimension of the latent space
D = 784  # dimensionality of the input space (28 * 28 flattened pixels)

# Pre-processing: flatten each 28x28 image to a D-vector and scale pixels to [0, 1].
x_train = x_train.reshape(-1, D).astype('float32') / 255
x_test = x_test.reshape(-1, D).astype('float32') / 255
# Model choices (Kingma & Welling, "Auto-Encoding Variational Bayes"):
# - prior p(z): J-dimensional standard Gaussian;
# - probabilistic decoder p_theta(x|z): D-dimensional multivariate Bernoulli with
#   independent dimensions (eq. 11, Appendix C.1);
# - encoder q_phi(z|x): J-dimensional Gaussian with diagonal covariance (eq. 12,
#   Appendix C.1), so the KL term of the ELBO has a closed form (Appendix B).
# Encoder network: map x to the parameters (phi_mu, phi_logvar) of the approximate
# posterior q_phi(z|x); a sample z is drawn below. Only a single sample (L=1) is
# used, as in the paper's Monte Carlo approximation (sec 2.3).
x_input = Input(shape=(D,), name='x')  # symbolic tensor of shape (?, 784), ? = minibatch size
encoder_hidden = 500  # width of the encoder's single hidden layer
h_enc = Dense(encoder_hidden, activation='tanh')(x_input)  # as in eq (12), C.2
phi_mu = Dense(J, name='phi_mu')(h_enc)
phi_logvar = Dense(J, name='phi_logvar')(h_enc)
def sample_z(args):
    """Reparameterization trick: draw z ~ q_phi(z|x^i) as mu + sigma * eps, eps ~ N(0, I).

    args: the pair (phi_mu, phi_logvar) of tensors, each of shape (?, J),
    where ? is the minibatch size. Returns one sample per datapoint (L=1,
    as in the paper, sec 2.3) with shape (?, J).
    """
    mu, logvar = args
    sigma = K.exp(logvar) ** 0.5  # standard deviation from log-variance
    # dynamic K.shape (not a static shape) so the batch dimension stays flexible;
    # see https://github.com/keras-team/keras/issues/5211
    eps = K.random_normal(shape=K.shape(sigma))
    return mu + sigma * eps  # (?, J)
# Wrap the sampling op in a Lambda layer so it can participate in a Model; the
# wrapped function must take a single argument, hence the [phi_mu, phi_logvar]
# list (see https://github.com/keras-team/keras/issues/6263).
z = Lambda(sample_z, output_shape=(J,), name='sample_z')([phi_mu, phi_logvar])
encoder = Model(x_input, z, name='encoder')
encoder.summary()
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== x (InputLayer) (None, 784) 0 __________________________________________________________________________________________________ dense_1 (Dense) (None, 500) 392500 x[0][0] __________________________________________________________________________________________________ phi_mu (Dense) (None, 2) 1002 dense_1[0][0] __________________________________________________________________________________________________ phi_logvar (Dense) (None, 2) 1002 dense_1[0][0] __________________________________________________________________________________________________ sample_z (Lambda) (None, 2) 0 phi_mu[0][0] phi_logvar[0][0] ================================================================================================== Total params: 394,504 Trainable params: 394,504 Non-trainable params: 0 __________________________________________________________________________________________________
# decoder model; map z to the parameters theta_mu of the likelihood p_theta(x|z), in this case a multi-variate
# Bernoulli
z_input = Input((J,), name='z')
x_h_dim = 500 # size of the first hidden layer of the decoder network
x_h = Dense(x_h_dim, activation='tanh')(z_input) # eq (11), C.1
y = Dense(D, activation='sigmoid', name='x_logits')(x_h) # eq (11), C.1; NOTE: despite the layer name
# 'x_logits', the sigmoid makes these Bernoulli probabilities (per-pixel on-probabilities), not
# logits; shape (?, 784). The name is kept as-is since it is runtime-visible (e.g. in summaries).
decoder = Model(inputs=z_input, outputs=y, name='decoder')
decoder.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= z (InputLayer) (None, 2) 0 _________________________________________________________________ dense_2 (Dense) (None, 500) 1500 _________________________________________________________________ x_logits (Dense) (None, 784) 392784 ================================================================= Total params: 394,284 Trainable params: 394,284 Non-trainable params: 0 _________________________________________________________________
# build VAE as one Keras model: encode x to a latent sample, then decode back
x_output = decoder(encoder(x_input)) # (?, 784) Bernoulli probabilities (sigmoid output)
vae = Model(inputs=x_input, outputs=x_output)
vae.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= x (InputLayer) (None, 784) 0 _________________________________________________________________ encoder (Model) (None, 2) 394504 _________________________________________________________________ decoder (Model) (None, 784) 394284 ================================================================= Total params: 788,788 Trainable params: 788,788 Non-trainable params: 0 _________________________________________________________________
# Objective: minimize -ELBO = reconstruction NLL + KL divergence, averaged over the batch.
# Reconstruction term: 'binary_crossentropy' averages over the D pixels, but the
# Bernoulli log-likelihood is a *sum* over pixels, hence the rescaling by D.
recon_loss = D * binary_crossentropy(x_input, x_output)  # shape (?,)
# KL(q_phi(z|x) || p(z)) in closed form for diagonal-Gaussian q and standard-normal
# prior p (Appendix B); negated because we minimize -ELBO.
kl_div = -0.5 * K.sum(1 + phi_logvar - phi_mu ** 2 - K.exp(phi_logvar), axis=-1)  # shape (?,)
# mean over the minibatch makes the loss batch-size independent
vae.add_loss(K.mean(recon_loss + kl_div))
# loss=None: the objective is attached via add_loss; a loss kwarg would have to be
# a function of (y_true, y_pred) rather than a custom op over internal tensors.
vae.compile(loss=None, optimizer='adadelta')
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: UserWarning: Output "decoder" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "decoder" during training.
epochs = 40
batch_size = 100
# no y targets are passed (None): the objective was attached with add_loss above
vae.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None))
Train on 60000 samples, validate on 10000 samples Epoch 1/40 60000/60000 [==============================] - 7s 119us/step - loss: 200.7443 - val_loss: 183.9392 Epoch 2/40 60000/60000 [==============================] - 6s 102us/step - loss: 181.1966 - val_loss: 178.0154 Epoch 3/40 60000/60000 [==============================] - 6s 102us/step - loss: 177.0143 - val_loss: 175.3122 Epoch 4/40 24300/60000 [===========>..................] - ETA: 3s - loss: 174.197960000/60000 [==============================] - 6s 102us/step - loss: 173.1681 - val_loss: 170.7559 Epoch 5/40 60000/60000 [==============================] - 6s 102us/step - loss: 169.7557 - val_loss: 168.5747 Epoch 6/40 60000/60000 [==============================] - 6s 102us/step - loss: 167.4907 - val_loss: 165.6503 Epoch 7/40 60000/60000 [==============================] - 6s 101us/step - loss: 165.8160 - val_loss: 165.5939 Epoch 8/40 60000/60000 [==============================] - 6s 102us/step - loss: 164.5859 - val_loss: 165.5661 Epoch 9/40 60000/60000 [==============================] - 6s 102us/step - loss: 163.4062 - val_loss: 163.7000 Epoch 10/40 60000/60000 [==============================] - 6s 101us/step - loss: 162.4083 - val_loss: 161.6195 Epoch 11/40 49400/60000 [=======================>......] - ETA: 1s - loss: 161.122360000/60000 [==============================] - 6s 101us/step - loss: 161.1199 - val_loss: 160.4683 Epoch 12/40 60000/60000 [==============================] - 6s 102us/step - loss: 160.0924 - val_loss: 160.5189 Epoch 13/40 60000/60000 [==============================] - 6s 102us/step - loss: 158.8696 - val_loss: 158.8651 Epoch 14/40 60000/60000 [==============================] - 6s 101us/step - loss: 158.4244 - val_loss: 157.7781 Epoch 15/40 2300/60000 [>.............................] 
- ETA: 5s - loss: 157.194660000/60000 [==============================] - 6s 102us/step - loss: 157.6936 - val_loss: 157.2325 Epoch 16/40 60000/60000 [==============================] - 6s 102us/step - loss: 157.3197 - val_loss: 157.5854 Epoch 17/40 60000/60000 [==============================] - 6s 102us/step - loss: 156.6138 - val_loss: 155.9665 Epoch 18/40 54600/60000 [==========================>...] - ETA: 0s - loss: 156.750560000/60000 [==============================] - 6s 102us/step - loss: 156.7074 - val_loss: 155.4256 Epoch 19/40 60000/60000 [==============================] - 6s 101us/step - loss: 155.8403 - val_loss: 156.6118 Epoch 20/40 60000/60000 [==============================] - 6s 103us/step - loss: 155.6840 - val_loss: 157.0905 Epoch 21/40 60000/60000 [==============================] - 6s 101us/step - loss: 155.2082 - val_loss: 159.2843 Epoch 22/40 3100/60000 [>.............................] - ETA: 5s - loss: 153.700860000/60000 [==============================] - 6s 102us/step - loss: 155.2330 - val_loss: 154.5763 Epoch 23/40 60000/60000 [==============================] - 6s 101us/step - loss: 154.4578 - val_loss: 156.1153 Epoch 24/40 60000/60000 [==============================] - 6s 102us/step - loss: 154.2349 - val_loss: 154.1508 Epoch 25/40 55200/60000 [==========================>...] - ETA: 0s - loss: 154.238560000/60000 [==============================] - 6s 102us/step - loss: 154.2291 - val_loss: 152.9378 Epoch 26/40 60000/60000 [==============================] - 6s 101us/step - loss: 153.7390 - val_loss: 153.1745 Epoch 27/40 60000/60000 [==============================] - 6s 102us/step - loss: 153.8308 - val_loss: 153.0685 Epoch 28/40 60000/60000 [==============================] - 6s 102us/step - loss: 153.3736 - val_loss: 153.4663 Epoch 29/40 2900/60000 [>.............................] 
- ETA: 5s - loss: 151.109760000/60000 [==============================] - 6s 101us/step - loss: 152.9898 - val_loss: 152.3988 Epoch 30/40 60000/60000 [==============================] - 6s 102us/step - loss: 152.8944 - val_loss: 152.6078 Epoch 31/40 60000/60000 [==============================] - 6s 101us/step - loss: 152.9080 - val_loss: 152.6174 Epoch 32/40 56500/60000 [===========================>..] - ETA: 0s - loss: 152.183360000/60000 [==============================] - 6s 101us/step - loss: 152.1455 - val_loss: 154.7331 Epoch 33/40 60000/60000 [==============================] - 6s 102us/step - loss: 152.7390 - val_loss: 152.5302 Epoch 34/40 60000/60000 [==============================] - 6s 102us/step - loss: 152.1273 - val_loss: 153.2042 Epoch 35/40 60000/60000 [==============================] - 6s 101us/step - loss: 151.8123 - val_loss: 156.1782 Epoch 36/40 2900/60000 [>.............................] - ETA: 5s - loss: 165.236360000/60000 [==============================] - 6s 101us/step - loss: 152.3245 - val_loss: 155.5324 Epoch 37/40 60000/60000 [==============================] - 6s 100us/step - loss: 151.8567 - val_loss: 153.2496 Epoch 38/40 60000/60000 [==============================] - 6s 101us/step - loss: 151.8368 - val_loss: 151.1960 Epoch 39/40 56300/60000 [===========================>..] - ETA: 0s - loss: 151.456760000/60000 [==============================] - 6s 101us/step - loss: 151.3798 - val_loss: 153.0719 Epoch 40/40 60000/60000 [==============================] - 6s 102us/step - loss: 151.5535 - val_loss: 152.6939
<keras.callbacks.History at 0x7f2369193898>
# visualize the posterior distributions q_phi(z|x) on test samples: each test image
# is encoded to a single latent sample and scattered in the 2-D latent space,
# colored by its digit label
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
# cmap = plt.cm.rainbow
# NOTE(review): plt.cm.get_cmap was removed in matplotlib >= 3.9; on modern
# matplotlib use matplotlib.colormaps['rainbow'].resampled(10) instead — confirm
# the installed version before upgrading.
cmap = plt.cm.get_cmap('rainbow', 10) # use 10 discrete colors, one per digit class; https://stackoverflow.com/questions/14777066
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test, cmap=cmap)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.colorbar()
plt.show()
# visualize the decoder distributions p_theta(x|z) for points on a grid in the
# latent space: display an n x n 2D manifold of decoded digits
n = 20
digit_size = 28
# linearly spaced coordinates covering the region of the latent space where the
# posterior mass lies (cf. the scatter plot above)
lims = [-5, 5]
grid_x = np.linspace(*lims, n)
grid_y = np.linspace(*lims, n)[::-1]  # reversed so large z[1] appears at the top row

# Batch all n*n grid points into ONE predict call instead of one call per point
# (the per-point loop issued 400 separate predict calls). Row-major order:
# row i <-> grid_y[i], column j <-> grid_x[j].
z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])  # (n*n, 2)
x_decoded = decoder.predict(z_grid)  # (n*n, 784) pixel probabilities

# stitch the decoded digits into a single (n*28) x (n*28) image:
# (n*n, 784) -> (row, col, h, w) -> (row, h, col, w) -> flat image
figure = (x_decoded
          .reshape(n, n, digit_size, digit_size)
          .transpose(0, 2, 1, 3)
          .reshape(digit_size * n, digit_size * n))

plt.figure(figsize=(10, 10))
# place one tick at the center of each digit cell, labeled with its latent coordinate
start_range = digit_size // 2
end_range = n * digit_size + start_range + 1
pixel_range = np.arange(start_range, end_range, digit_size)
plt.xticks(pixel_range, np.round(grid_x, 1))
plt.yticks(pixel_range, np.round(grid_y, 1))
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap='Greys_r')
plt.show()