This notebook is a simple demonstration of Generative Adversarial Networks (GANs). We use the MNIST handwritten digit dataset and a dense, MLP-style architecture for both the generator and discriminator. The training approach follows the early work on GANs by Goodfellow et al.
Note that this is intended as a simple demonstration only and omits a number of advanced techniques that can dramatically improve the performance of GANs. See some of my other GAN notebooks for more advanced examples.
import matplotlib
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (11, 11)
import functools
import gzip
import operator
import struct
import tqdm.notebook as tqdm
import munch
import skimage as ski
import matplotlib.pyplot as plt
import numpy as np
import torch as th
This is just a Torch dataset for the standard MNIST data.
class MNIST(th.utils.data.Dataset):
'''Simple MNIST dataset containing only the images.
'''
def __init__(self):
'''Initialize a new MNIST dataset.
'''
super().__init__()
self.imgs = self._load_imgs()
@staticmethod
def _load_imgs():
'''Load all of the training images. This dataset is relatively small,
so we just keep everything in memory on the CPU side.
Thanks:
https://stackoverflow.com/questions/39969045/parsing-yann-lecuns-mnist-idx-file-format
'''
filename = '../data/torchvision/MNIST/raw/train-images-idx3-ubyte.gz'
with gzip.open(filename, mode='rb') as fh:
_magic, size = struct.unpack('>II', fh.read(8))
nrows, ncols = struct.unpack('>II', fh.read(8))
data = np.frombuffer(fh.read(), dtype=np.dtype(np.uint8).newbyteorder('>'))
imgs = data.reshape((size, nrows, ncols)) / 255.
return th.as_tensor(imgs, dtype=th.float32)
def __getitem__(self, idx):
'''Return a single image specified by the given index.
'''
return self.imgs[idx]
def __len__(self):
'''Return the number of images in the dataset.
'''
return self.imgs.shape[0]
def plot(self, idx):
'''Plot the image specified by the given index.
'''
fig, ax = plt.subplots()
ax.imshow(self.imgs[idx], cmap=plt.cm.gray)
ax.axis('off')
return munch.Munch(fig=fig, ax=ax)
def plot_montage(self, n=36):
'''Plot a montage of `n` randomly selected images.
'''
idxs = th.randperm(len(self))[:n]
imgs = self.imgs[idxs]
montage = ski.util.montage(imgs)
fig, ax = plt.subplots()
ax.imshow(montage, cmap=plt.cm.gray)
ax.axis('off')
return munch.Munch(fig=fig, ax=ax)
data = MNIST()
len(data)
60000
data.plot(3);
data.plot_montage();
Next, we define a generator that maps a flat, random-normal latent vector to an image of the specified size. Since we're building a simple demonstration, the architecture is just a dense, MLP-style network with hyperbolic tangent activations. Note that we do not condition the generator on class labels in this example.
class _Tanh(th.nn.Module):
'''Hyperbolic tangent, scaled as described in Lecun's backprop tricks paper.
'''
def __init__(self):
super().__init__()
self.register_buffer('_a', th.as_tensor(1.7159), persistent=False)
self.register_buffer('_b', th.as_tensor(2. / 3.), persistent=False)
def forward(self, x):
return self._a * th.tanh(self._b * x)
class _Linear(th.nn.Linear):
'''Dense linear layer with linear Kaiming weight initialization.
'''
def reset_parameters(self):
th.nn.init.kaiming_normal_(self.weight, nonlinearity='linear')
        if self.bias is not None:
            th.nn.init.zeros_(self.bias)
class Generator(th.nn.Module):
'''MLP-style generator.
'''
def __init__(self, latent_dim=32, img_size=(28, 28), layer_specs=(64, 128, 256)):
'''Initialize a new generator.
latent_dim (int):
Expected number of dimensions for the input latent vector
used for generating images.
img_size (tuple(int)):
Two-tuple of integers specifying the height and width of
the output images in pixels.
layer_specs (tuple(int)):
A tuple where each value specifies the number of hidden
units in each layer. The number of values in this tuple
determines the number of layers.
'''
super().__init__()
self.latent_dim = latent_dim
self.img_size = img_size
self.hidden = th.nn.Sequential()
layer_in, layer_out = None, latent_dim
for num_units in layer_specs:
layer_in, layer_out = layer_out, num_units
self.hidden.append(
th.nn.Sequential(
_Linear(layer_in, layer_out),
_Tanh(),
)
)
out_dim = functools.reduce(operator.mul, img_size)
layer_in, layer_out = layer_out, out_dim
self.visible = _Linear(layer_in, layer_out)
def forward(self, z):
'''Forward pass generates synthetic images from
the latent vector `z`.
'''
batch_size, latent_dim = z.shape
assert latent_dim == self.latent_dim
x = self.visible(self.hidden(z))
return th.sigmoid(x + 0.5).view(batch_size, *self.img_size)
@th.inference_mode()
def plot_montage(self, n=36, z=None):
'''Plot a montage of `n` synthetic images. If `z` is ``None`` then
the latent vectors will be drawn from the random normal distribution.
'''
if z is None:
z = th.randn(n, self.latent_dim)
imgs = self(z).cpu()
montage = ski.util.montage(imgs)
fig, ax = plt.subplots()
ax.imshow(montage, cmap=plt.cm.gray)
ax.axis('off')
return munch.Munch(fig=fig, ax=ax)
generator = Generator()
generator
Generator(
  (hidden): Sequential(
    (0): Sequential(
      (0): _Linear(in_features=32, out_features=64, bias=True)
      (1): _Tanh()
    )
    (1): Sequential(
      (0): _Linear(in_features=64, out_features=128, bias=True)
      (1): _Tanh()
    )
    (2): Sequential(
      (0): _Linear(in_features=128, out_features=256, bias=True)
      (1): _Tanh()
    )
  )
  (visible): _Linear(in_features=256, out_features=784, bias=True)
)
# The number of parameters in our generator
sum(p.numel() for p in generator.parameters())
244944
# Make sure that we get reasonable outputs for a batch of latent input vectors
with th.inference_mode():
im = generator(th.randn(32, generator.latent_dim))
plt.imshow(im[0], cmap=plt.cm.gray);
plt.axis('off');
plt.colorbar(shrink=0.8);
plt.tight_layout();
im.shape, im.min(), im.max(), im.mean(), im.std()
(torch.Size([32, 28, 28]), tensor(0.0813), tensor(0.9768), tensor(0.6089), tensor(0.1613))
Now we define our discriminator. Again, we use an MLP-style network with a design that is very similar to our generator, except that it takes images as its inputs and outputs logits that indicate real or fake. We also inject the mean image across each batch into the network as a simple form of minibatch discrimination, which helps to prevent mode collapse.
class Discriminator(th.nn.Module):
'''MLP-style discriminator. This is a binary classifier, i.e.,
    there is no class conditioning. The discriminator does, however,
use the mean images for each batch in order to perform a simple
version of "minibatch discrimination" that helps to prevent
mode collapse.
'''
def __init__(self, img_size=(28, 28), layer_specs=(256, 128, 64)):
'''Initialize a new discriminator.
img_size (tuple(int)):
Two-tuple of integers specifying the height and width of
the input images in pixels.
layer_specs (tuple(int)):
A tuple where each value specifies the number of hidden
units in each layer. The number of values in this tuple
determines the number of layers.
'''
super().__init__()
self.img_size = img_size
self.in_dim = functools.reduce(operator.mul, img_size)
self.hidden = th.nn.Sequential()
layer_in, layer_out = None, self.in_dim * 2
for num_units in layer_specs:
layer_in, layer_out = layer_out, num_units
self.hidden.append(
th.nn.Sequential(
_Linear(layer_in, layer_out),
_Tanh(),
)
)
layer_in, layer_out = layer_out, 1
self.visible = _Linear(layer_in, layer_out)
def forward(self, x):
        '''Forward pass produces a logit for each image indicating
        whether it is real (positive) or fake (negative).
'''
        # The mean and standard deviation are hard-coded for MNIST
x = (x - 0.1307) / 0.3081
x = x.flatten(1)
batch_size, in_dim = x.shape
assert in_dim == self.in_dim
# We inject the mean images for each batch. This is a
# simplistic form of minibatch discrimination that helps
# to prevent mode collapse.
x_mean = x.mean(dim=0, keepdim=True).expand(batch_size, -1)
x = th.cat((x, x_mean), dim=1)
return self.visible(self.hidden(x))
discriminator = Discriminator()
discriminator
Discriminator(
  (hidden): Sequential(
    (0): Sequential(
      (0): _Linear(in_features=1568, out_features=256, bias=True)
      (1): _Tanh()
    )
    (1): Sequential(
      (0): _Linear(in_features=256, out_features=128, bias=True)
      (1): _Tanh()
    )
    (2): Sequential(
      (0): _Linear(in_features=128, out_features=64, bias=True)
      (1): _Tanh()
    )
  )
  (visible): _Linear(in_features=64, out_features=1, bias=True)
)
# The number of parameters in the discriminator
sum(p.numel() for p in discriminator.parameters())
442881
# Verify that we get a reasonable output
with th.inference_mode():
logits = discriminator(im)
logits.min(), logits.max(), logits.mean(), logits.std(), th.sigmoid(logits[0])
(tensor(-0.7343), tensor(0.3366), tensor(-0.2641), tensor(0.2429), tensor([0.4252]))
Now that the generator and discriminator are defined, we set up the training procedure. This is a very simple adversarial training method in which the discriminator minimizes cross-entropy for labeling real images as real and fake images as fake. Simultaneously, the generator is trained to produce fake images that the discriminator will label as real.
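For reference, this is the non-saturating formulation from Goodfellow et al. With discriminator output $D(\cdot)$ (a sigmoid applied to the logit), generator $G$, and latent prior $z \sim \mathcal{N}(0, I)$, the discriminator minimizes

$$\mathcal{L}_D = -\,\mathbb{E}_{x \sim p_\text{data}}\big[\log D(x)\big] - \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big],$$

while the generator minimizes

$$\mathcal{L}_G = -\,\mathbb{E}_{z}\big[\log D(G(z))\big].$$

The code below implements exactly this, except that the real targets for the discriminator are scaled by a one-sided label smoothing factor and the two discriminator terms are averaged.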
def train(generator, discriminator, data, epochs=350, batch_size=128, lr=0.0002,
lr_decay=0.99, label_smooth=0.9, device=0):
'''Simple, unconditioned adversarial training with cross-entropy loss.
    generator (torch.nn.Module):
        The generator module should take a batch of random latent vectors as
        inputs and produce a batch of images as outputs.
    discriminator (torch.nn.Module):
        The discriminator should take a batch of images as inputs and produce
        a real-vs-fake logit for each image as output.
data (torch.utils.data.Dataset):
A Torch dataset that produces greyscale images in HW format.
epochs (int):
The number of training epochs.
batch_size (int):
The batch size.
lr (float):
The learning rate to use.
lr_decay (float):
Exponential decay rate for the learning rate, applied after
each training epoch.
label_smooth (float):
The amount of one-sided label smoothing to apply to the
discriminator. This prevents very high confidence in
"real" labels, which helps to prevent very small gradients
that can slow training and result in mode collapse.
device (int):
The GPU device ID to use.
'''
generator.to(device)
discriminator.to(device)
generator.train()
discriminator.train()
g_opt = th.optim.RMSprop(generator.parameters(), alpha=0.999, lr=lr)
d_opt = th.optim.RMSprop(discriminator.parameters(), alpha=0.999, lr=lr)
g_sched = th.optim.lr_scheduler.ExponentialLR(g_opt, gamma=lr_decay)
d_sched = th.optim.lr_scheduler.ExponentialLR(d_opt, gamma=lr_decay)
g_losses, d_losses = [], []
dataloader = th.utils.data.DataLoader(
data,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True,
)
for epoch in tqdm.trange(epochs):
for i, real_imgs in enumerate(dataloader):
real_imgs = real_imgs.to(device, non_blocking=True)
real_targs = th.ones(batch_size, 1, device=device)
fake_targs = th.zeros(batch_size, 1, device=device)
## Generator
z = th.randn(batch_size, generator.latent_dim, device=device)
fake_imgs = generator(z)
g_loss = th.nn.functional.binary_cross_entropy_with_logits(
discriminator(fake_imgs), real_targs)
g_losses.append(g_loss.item())
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
## Discriminator
with th.no_grad():
z = th.randn(batch_size, generator.latent_dim, device=device)
fake_imgs = generator(z)
fake_loss = th.nn.functional.binary_cross_entropy_with_logits(
discriminator(fake_imgs), fake_targs)
real_loss = th.nn.functional.binary_cross_entropy_with_logits(
discriminator(real_imgs), real_targs * label_smooth)
d_loss = (fake_loss + real_loss) / 2.
d_losses.append(d_loss.item())
d_opt.zero_grad()
d_loss.backward()
d_opt.step()
g_sched.step()
d_sched.step()
generator.eval()
discriminator.eval()
generator.cpu()
discriminator.cpu()
return munch.Munch(
g_losses=g_losses,
d_losses=d_losses,
)
result = train(generator, discriminator, data, device=1)
Here we see the loss for both the generator and discriminator. Things look fairly stable, which is good. The generator's loss is typically higher, meaning that it is somewhat "chasing" the discriminator. The variation in the loss settles down over time because of the learning rate schedule. These observations are generally encouraging, but it is often difficult to tell how well a GAN is working from the losses alone. Tracking a metric like FID would likely be helpful here; a minimal sketch follows the loss plots below.
plt.plot(result.g_losses);
plt.title('Generator Loss During Training');
plt.xlabel('Training Step');
plt.ylabel('Loss');
plt.plot(result.d_losses);
plt.title('Discriminator Loss During Training');
plt.xlabel('Training Step');
plt.ylabel('Loss');
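As noted above, a sample-quality metric such as FID would complement the loss curves. Below is a minimal sketch of the FID computation, assuming that `feats_real` and `feats_fake` are feature matrices already extracted from real and generated images by some fixed feature extractor (the standard choice is an ImageNet-pretrained Inception network; for MNIST, a small pretrained digit classifier is often substituted):
import numpy as np
from scipy import linalg

def fid(feats_real, feats_fake):
    '''Frechet distance between Gaussians fitted to two feature sets.

    Both arguments are (num_samples, num_features) arrays of features
    from real and generated images, respectively.
    '''
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_f = np.cov(feats_fake, rowvar=False)
    # Matrix square root of the covariance product; numerical error can
    # introduce a small imaginary component, so we keep the real part.
    covmean = linalg.sqrtm(cov_r @ cov_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu_r - mu_f
    return diff @ diff + np.trace(cov_r + cov_f - 2. * covmean)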
Next, we see some images produced by our generator. The images are a bit messy, but in most cases we can clearly identify which digit is represented. Many of them could plausibly pass as genuine.
generator.plot_montage();
generator.plot_montage();
Next, we experiment with linearly interpolating between two latent-space vectors and passing the interpolated vectors through our generator. This shows how the images evolve from one digit to another as the latent inputs vary.
z0 = th.randn(1, generator.latent_dim)
z1 = th.randn(1, generator.latent_dim)
n = 36
alpha = th.linspace(0., 1., n)[:, None]
z = alpha * z0 + (1. - alpha) * z1
generator.plot_montage(n=n, z=z);
z0 = th.randn(1, generator.latent_dim)
z1 = th.randn(1, generator.latent_dim)
n = 36
alpha = th.linspace(0., 1., n)[:, None]
z = alpha * z0 + (1. - alpha) * z1
generator.plot_montage(n=n, z=z);
z0 = th.randn(1, generator.latent_dim)
z1 = th.randn(1, generator.latent_dim)
n = 36
alpha = th.linspace(0., 1., n)[:, None]
z = alpha * z0 + (1. - alpha) * z1
generator.plot_montage(n=n, z=z);
z0 = th.randn(1, generator.latent_dim)
z1 = th.randn(1, generator.latent_dim)
n = 36
alpha = th.linspace(0., 1., n)[:, None]
z = alpha * z0 + (1. - alpha) * z1
generator.plot_montage(n=n, z=z);
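One possible refinement, not used above: linear interpolation between two Gaussian latent vectors passes through points with smaller norm than typical samples, so spherical linear interpolation (slerp) is sometimes preferred for traversals like these. A minimal sketch, assuming the two endpoints are not parallel:
def slerp(z0, z1, alpha):
    '''Spherical linear interpolation between latent vectors `z0` and `z1`.

    `alpha` is a column of interpolation weights in [0, 1].
    '''
    z0n = z0 / z0.norm(dim=-1, keepdim=True)
    z1n = z1 / z1.norm(dim=-1, keepdim=True)
    # Angle between the two latent directions.
    omega = th.acos((z0n * z1n).sum(dim=-1, keepdim=True).clamp(-1., 1.))
    so = th.sin(omega)
    return (th.sin((1. - alpha) * omega) / so) * z0 + (th.sin(alpha * omega) / so) * z1

z0 = th.randn(1, generator.latent_dim)
z1 = th.randn(1, generator.latent_dim)
alpha = th.linspace(0., 1., 36)[:, None]
generator.plot_montage(n=36, z=slerp(z0, z1, alpha));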
And that concludes my simple demonstration of unconditioned GANs on the MNIST dataset! Although this approach seems to work, we can definitely do better. Check out my next notebook on conditioned GANs with MNIST to see how adding class-label information improves performance.