Generative Adversarial Networks (GANs), introduced by Ian Goodfellow and his collaborators in 2014, are neural nets we can train to produce new images (or other kinds of data) that look as though they came from our true data distribution. In this notebook, we'll implement a small GAN for generating images that look as though they come from the MNIST dataset.
The key insight behind the GAN is to pit two neural networks against each other. On the one hand is the Generator, a neural network that takes random noise as input and produces an image as output. On the other hand is the Discriminator, which takes in an image and classifies it as real (from MNIST) or fake (from our Generator). During training, we alternate between training the Generator to fool the Discriminator, and training the Discriminator to call the Generator's bluff.
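Formally, this game can be written as the minimax objective from the original GAN paper, where D(x) is the discriminator's estimated probability that x is real and G(z) is the generator's output on noise z:

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to push this value up (classify correctly), while the generator tries to push it down (get its samples classified as real).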
Implementing a GAN in TensorFlow will give you practice turning more involved models into working code, and is also a great showcase for TensorFlow's variable scope feature. (Variable scope has made cameos in previous tutorials, but we'll discuss it in a bit more depth here. If you want to see how variable scope is used in TensorFlow Slim, definitely go revisit Kevin Liang's VAE tutorial!)
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import time
# Session configuration (useful when running on a GPU): grow GPU memory
# allocation as needed, and log which device each op is placed on.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
As in previous examples, we'll use MNIST, because it's a small and easy-to-use dataset that comes bundled with TensorFlow.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Let's define some utility functions that will help us quickly construct layers for use in our model. There are two things worth noting here.

First, instead of tf.Variable, we use tf.get_variable. The reason for this is a bit subtle, and you may want to skip this and come back to it once you've seen the rest of the code. Here's the basic explanation. Later on in this notebook, we will call fully_connected_layer from a couple of different places. Sometimes, we will want new variables to be added to the graph, because we are creating an entirely new layer of our network. Other times, however, we will want to use the same weights as an already-existing layer, but acting on different inputs.

For example, the Discriminator network will appear twice in our computational graph: in one case, the input neurons will be connected to the "real data" placeholder (which we will feed MNIST images), and in the other, they will be connected to the output of the Generator. Although these networks form two separate parts of our computational graph, we want them to share the same weights; conceptually, there is one Discriminator function that gets applied twice, not two different functions altogether. Since tf.Variable always creates a new variable when called, it would not be appropriate for use here.

Variable scoping solves this problem. Whenever we are adding nodes to a graph, we are operating within a scope. Scopes can be named, and you can create a new scope using tf.variable_scope('name') (more on this later). When a scope is open, it can optionally be in reuse mode. The result of calling tf.get_variable depends on whether you are in reuse mode or not. If not (this is the default), tf.get_variable will create a new variable, or raise an error if a variable by the same name already exists in the current scope. If you are in reuse mode, the behavior is the opposite: tf.get_variable will look up and return an existing variable (with the specified name) within your scope, or raise an error if it doesn't exist. By carefully controlling our scopes later on, we can create exactly the graph we want, with variables shared across the graph where appropriate.
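To make these reuse semantics concrete, here is a pure-Python analogy (not TensorFlow code; the VariableScope class and its names are made up for illustration) of how a scoped get_variable behaves in its two modes:

```python
class VariableScope:
    """Toy stand-in for tf.variable_scope: a named registry of variables."""
    def __init__(self, name):
        self.name = name
        self._vars = {}      # full variable name -> value (here, just a float)
        self.reuse = False   # analogous to calling scope.reuse_variables()

    def get_variable(self, name, initializer=0.0):
        full_name = self.name + '/' + name
        if self.reuse:
            # Reuse mode: the variable must already exist; look it up.
            if full_name not in self._vars:
                raise ValueError('Variable %s does not exist' % full_name)
            return self._vars[full_name]
        # Create mode (the default): the variable must NOT already exist.
        if full_name in self._vars:
            raise ValueError('Variable %s already exists' % full_name)
        self._vars[full_name] = initializer
        return self._vars[full_name]

scope = VariableScope('discriminator')
w1 = scope.get_variable('weights', initializer=0.5)  # creates discriminator/weights
scope.reuse = True
w2 = scope.get_variable('weights')                   # looks up the same variable
assert w1 == w2
```

The real tf.get_variable works the same way, except that the scope is tracked implicitly by the `with tf.variable_scope(...)` context you are inside when you call it.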
Second, the variables_from_scope function lists all variables created within a given scope. This will be useful later, when we want to update all "discriminator" variables, but no "generator" variables, or vice versa.
def shape(tensor):
    """
    Get the static shape of a tensor as a tuple. This runs at graph-construction
    time, not at execution time, so it cannot see dimensions that are only
    determined when data is fed in (e.g. placeholder dimensions left as None).
    """
    return tuple([d.value for d in tensor.get_shape()])
def fully_connected_layer(in_tensor, out_units, activation_function=tf.nn.relu):
    """
    Add a fully connected layer to the default graph, taking as input `in_tensor`, and
    creating a hidden layer of `out_units` neurons. This should be called within a unique
    variable scope. Creates variables W and b, and computes activation_function(in * W + b).
    """
    _, num_features = shape(in_tensor)
    W = tf.get_variable("weights", [num_features, out_units], initializer=tf.truncated_normal_initializer(stddev=0.1))
    b = tf.get_variable("biases", [out_units], initializer=tf.constant_initializer(0.1))
    return activation_function(tf.matmul(in_tensor, W) + b)
def variables_from_scope(scope_name):
    """
    Returns a list of all variables in a given scope. This is useful when
    you'd like to back-propagate only to weights in one part of the network
    (in our case, the generator or the discriminator).
    """
    return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name)
We'll also provide a simple function for displaying a few 28×28-pixel images. This will help us understand the progress of our GAN as it trains; we'll use it to visualize the generated 'fake digit' images.
def visualize_row(images, img_width=28, cmap='gray'):
    """
    Takes in a tensor of images of given width, and displays them in a column
    in a plot, using `cmap` to map from numbers to colors.
    """
    im = np.reshape(images, [-1, img_width])
    plt.figure()
    plt.imshow(im, cmap=cmap)
    plt.show()
A GAN is made up of two smaller networks: a generator and a discriminator. The generator is responsible for sampling images from a distribution that we hope will get closer and closer, as we train, to the real data distribution.
Neural networks are deterministic, so in order to sample a new image from the generator, we first create some random noise z (in our case, z will be a 100-dimensional uniform random variable) and then feed that noise to the network. You can think of z as being a latent, low-dimensional representation of some image G(z), though in a vanilla GAN, it is usually difficult to interpret z's components in a meaningful way.
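Concretely, sampling a batch of noise vectors is one NumPy call (the batch size and dimensions here match the defaults used later in this notebook):

```python
import numpy as np

batch_size, z_dim = 50, 100

# Each row is one 100-dimensional z, with components drawn uniformly from [-1, 1).
z = np.random.uniform(-1.0, 1.0, size=[batch_size, z_dim])

print(z.shape)  # (50, 100)
```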
Our generator is a dead-simple multi-layer perceptron (feed-forward network), with 128 hidden units.
def generator(z):
    """
    Given random noise `z`, use a simple MLP with 128 hidden units to generate a
    sample image (784 values between 0 and 1, enforced with the sigmoid function).
    """
    with tf.variable_scope("fc1"):
        fc1 = fully_connected_layer(z, 128)
    with tf.variable_scope("fc2"):
        return fully_connected_layer(fc1, 784, activation_function=tf.sigmoid)
Although it isn't necessary, it makes some sense for our discriminator to mirror the generator's architecture, as we do here. The discriminator takes in an image (perhaps a real one from the MNIST dataset, perhaps a fake one from our generator) and attempts to classify it as real (1) or fake (0). Our architecture is again a simple MLP, taking 784 pixels down to 128 hidden units, and finally down to a probability.
def discriminator(x):
    """
    This discriminator network takes in a tensor with shape [batch, 784], and classifies
    each example image as real or fake. The network it uses is quite simple: a fully connected
    layer with ReLU activation takes us down to 128 dimensions, then we collapse that to 1 number
    in [0, 1] using a fully-connected layer with sigmoid activation. The result can be interpreted
    as a probability, the discriminator's strength-of-belief that a sample is from the
    real data distribution.
    """
    with tf.variable_scope("fc1"):
        fc1 = fully_connected_layer(x, 128)
    with tf.variable_scope("fc2"):
        return fully_connected_layer(fc1, 1, activation_function=tf.sigmoid)
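To see the shapes flowing through this 784 → 128 → 1 architecture, here is a NumPy-only sketch of the same forward pass (the weights are random placeholders, not trained values, so the probabilities it produces are meaningless):

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.RandomState(0)
x = rng.rand(50, 784)                    # a batch of 50 flattened 28x28 "images"

# Randomly initialized stand-ins for the two layers' weights and biases.
W1 = rng.normal(scale=0.1, size=(784, 128))
b1 = np.full(128, 0.1)
W2 = rng.normal(scale=0.1, size=(128, 1))
b2 = np.full(1, 0.1)

h = relu(x.dot(W1) + b1)                 # hidden layer: shape [50, 128]
p = sigmoid(h.dot(W2) + b2)              # "probability real": shape [50, 1]

print(p.shape)  # (50, 1)
```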
Given a generator and discriminator, we can now set up the GAN's computational graph.
We use TensorFlow's variable scope feature for two purposes.
First, it helps separate the variables used by the generator and by the discriminator; this is important, because when training, we want to alternate between updating each set of variables according to a different objective.
Second, scoping helps us reuse the same set of discriminator weights both for the operations we perform on real images and for those performed on fake images. To achieve this, after calling discriminator for the first time (and creating these weight variables), we tell our current scope to reuse_variables(), meaning that on our next call to discriminator, existing variables will be reused rather than creating new ones.
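For reference, the two losses constructed in the gan function below can be written (as averages over a batch) as:

```latex
\mathcal{L}_D = -\,\mathbb{E}\big[\log D(x) + \log(1 - D(G(z)))\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}\big[\log D(G(z))\big]
```

Note that this generator loss is the "non-saturating" variant suggested in the original GAN paper: maximizing log D(G(z)) gives the generator stronger gradients early in training than directly minimizing log(1 − D(G(z))).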
def gan(batch_size, z_dim):
    """
    Given some details about the training procedure (batch size, dimension of z),
    this function sets up the rest of the computational graph for the GAN.

    It returns a dictionary containing six ops/tensors: `train_d` and `train_g`, the
    optimization steps for the discriminator and generator; `real_data` and `noise`,
    two placeholders that should be fed in during training; `d_loss`, the discriminator loss
    (useful for estimating progress toward convergence); and `fake_data`, which can be
    evaluated (with noise in the feed_dict) to sample from the generator's distribution.
    """
    z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z')
    x = tf.placeholder(tf.float32, [batch_size, 784], name='x')

    with tf.variable_scope('generator'):
        fake_x = generator(z)

    # The discriminator is applied twice with shared weights: once to real
    # data, and once (after reuse_variables) to the generator's output.
    with tf.variable_scope('discriminator') as scope:
        d_on_real = discriminator(x)
        scope.reuse_variables()
        d_on_fake = discriminator(fake_x)

    # Non-saturating generator loss and standard discriminator loss.
    g_loss = -tf.reduce_mean(tf.log(d_on_fake))
    d_loss = -tf.reduce_mean(tf.log(d_on_real) + tf.log(1. - d_on_fake))

    # Each optimizer only updates its own network's variables.
    optimize_d = tf.train.AdamOptimizer().minimize(d_loss, var_list=variables_from_scope("discriminator"))
    optimize_g = tf.train.AdamOptimizer().minimize(g_loss, var_list=variables_from_scope("generator"))

    return {'train_d': optimize_d,
            'train_g': optimize_g,
            'd_loss': d_loss,
            'fake_data': fake_x,
            'real_data': x,
            'noise': z}
Our training procedure is a bit more involved than in past demos. Here are the main differences: each training step feeds two inputs into the graph, a batch of real MNIST images and a batch of fresh noise vectors sampled with the np.random.uniform function; each iteration alternates between a generator update and a discriminator update; and every 1000 iterations, we print the discriminator loss and visualize a few generated samples to track progress.

def train_gan(iterations, batch_size=50, z_dim=100):
"""
Construct and train the GAN.
"""
model = gan(batch_size=batch_size, z_dim=z_dim)
def make_noise():
return np.random.uniform(-1.0, 1.0, [batch_size, z_dim])
def next_feed_dict():
return {model['real_data']: mnist.train.next_batch(batch_size)[0],
model['noise']: make_noise()}
initialize_all = tf.global_variables_initializer()
with tf.Session(config=config) as sess:
sess.run(initialize_all)
start_time = time.time()
for t in range(iterations):
sess.run(model['train_g'], feed_dict=next_feed_dict())
_, d_loss = sess.run([model['train_d'], model['d_loss']], feed_dict=next_feed_dict())
if t % 1000 == 0 or t+1 == iterations:
fake_data = sess.run(model['fake_data'], feed_dict={model['noise']: make_noise()})
print('Iter [%8d] Time [%5.4f] d_loss [%.4f]' % (t, time.time() - start_time, d_loss))
visualize_row(fake_data[:5])
It's time to run our GAN! Watch as it learns to draw recognizable digits in about three minutes.
train_gan(25000)
Iter [ 0] Time [0.4396] d_loss [1.7320]
Iter [ 1000] Time [11.4322] d_loss [0.0058]
Iter [ 2000] Time [22.3807] d_loss [0.0151]
Iter [ 3000] Time [33.0340] d_loss [0.1149]
Iter [ 4000] Time [43.4463] d_loss [0.1342]
Iter [ 5000] Time [54.4511] d_loss [0.2168]
Iter [ 6000] Time [65.2802] d_loss [0.3687]
Iter [ 7000] Time [75.8429] d_loss [0.5434]
Iter [ 8000] Time [86.5826] d_loss [0.6736]
Iter [ 9000] Time [93.5509] d_loss [0.4509]
Iter [ 10000] Time [96.8782] d_loss [0.6090]
Iter [ 11000] Time [100.2687] d_loss [0.7614]
Iter [ 12000] Time [103.6453] d_loss [0.5020]
Iter [ 13000] Time [107.0101] d_loss [0.4292]
Iter [ 14000] Time [110.2594] d_loss [0.7531]
Iter [ 15000] Time [113.6600] d_loss [0.8509]
Iter [ 16000] Time [116.9783] d_loss [0.5898]
Iter [ 17000] Time [120.3468] d_loss [0.6821]
Iter [ 18000] Time [123.8777] d_loss [0.5784]
Iter [ 19000] Time [127.1924] d_loss [0.9566]
Iter [ 20000] Time [134.0210] d_loss [1.0093]
Iter [ 21000] Time [144.2357] d_loss [0.8924]
Iter [ 22000] Time [154.9471] d_loss [0.8893]
Iter [ 23000] Time [166.1029] d_loss [0.9184]
Iter [ 24000] Time [177.5164] d_loss [0.7486]
Iter [ 24999] Time [188.8372] d_loss [0.7503]