The vanilla GAN we have seen does not take the class of an image into account: while it can produce realistic images that resemble those in the training set, it cannot be asked to produce images of a specific, desired class (for example, only images of cats). We can extend the vanilla GAN by incorporating extra conditioning information, for example a label specifying which class we want it to generate. To do this, we pass a conditioning vector $c$ (e.g. a one-hot encoding of the class label) to both the generator and the discriminator, so they become $G(z, c)$ and $D(x, c)$, respectively. The generator $G$ can then take into account which class it is supposed to generate, and $D$ can score how realistic an image is, given that it is intended to represent class $c$.
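Then, the objective function (a loss to be minimized) of the discriminator becomes: $$ J_D = \mathbb{E}_{(x, c) \sim p_{data}} [ - \log D(x, c) ] + \mathbb{E}_{z \sim q(z),\, c \sim p(c)} [ - \log(1 - D(G(z, c), c)) ] $$ where each real image $x$ is paired with its own label $c$, and for generated images the requested label $c$ is drawn from a label distribution $p(c)$.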
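And the objective function (also minimized) of the generator becomes: $$ J_G = \mathbb{E}_{z \sim q(z),\, c \sim p(c)} [ - \log D(G(z, c), c) ] $$ where $c \sim p(c)$ is the class label the generator is asked to produce.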
import os
import math
import random
import itertools
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from IPython import display
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# ---- data / loader settings ----
dataroot = 'data'      # directory CIFAR-10 is downloaded into / read from
workers = 2            # DataLoader worker processes
batchSize = 64
imageSize = 64         # images are rescaled to this size by the dataset transform
# ---- model sizes ----
nz = 100               # length of the latent vector fed to the generator
ngf = ndf = 64         # base feature-map widths of the generator / discriminator
# ---- optimisation ----
nepochs = 40
lr, beta1 = 0.0002, 0.5  # Adam learning rate and first-moment decay

# Draw a fresh seed, print it so a run can be reproduced, then seed Python's
# RNG, the CPU torch RNG and every CUDA device's RNG with it.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)
('Random Seed: ', 5567)
# CIFAR-10 training images, rescaled to imageSize and normalised per channel
# from [0, 1] to [-1, 1] so real images share the range of the generator's
# Tanh output.
# transforms.Scale was renamed to transforms.Resize (and later removed), so use
# whichever one the installed torchvision provides.
dataset = dset.CIFAR10(
    root=dataroot,
    download=True,
    transform=transforms.Compose([
        (transforms.Resize if hasattr(transforms, 'Resize') else transforms.Scale)(imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
)
# Shuffled mini-batches; the final batch of an epoch may be smaller than batchSize.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                         shuffle=True, num_workers=workers)
Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
nc = 3  # number of image channels (CIFAR-10 images are RGB)


# custom weights initialization, applied to G and D via Module.apply
def weights_init(m):
    """Re-initialise one module with the DCGAN scheme.

    Any module whose class name contains 'Conv' (Conv2d, ConvTranspose2d) gets
    weights drawn from N(0, 0.02). Any module whose class name contains
    'BatchNorm' gets its scale drawn from N(1, 0.02) and its shift set to 0.
    All other modules are left untouched. Intended to be passed to
    ``nn.Module.apply`` so it visits every submodule.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class Generator(nn.Module):
    """DCGAN generator.

    Maps latent codes of shape (N, nz, 1, 1) to images of shape (N, nc, 64, 64)
    with values in [-1, 1] (final Tanh). The spatial size grows 1 -> 4 -> 8 ->
    16 -> 32 -> 64 through five transposed convolutions.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Project the latent vector: nz x 1 x 1 -> (ngf*8) x 4 x 4
        layers = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Each stride-2 block doubles the spatial size and halves the width:
        # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers.append(nn.ConvTranspose2d(ngf * mult_in, ngf * mult_out, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ngf * mult_out))
            layers.append(nn.ReLU(True))
        # Last upsampling to image channels: (ngf) x 32 x 32 -> (nc) x 64 x 64
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Scores images of shape (N, nc, 64, 64) and returns an (N, 1) tensor of
    Sigmoid outputs in (0, 1). The final 4x4 convolution collapses the 4x4 map
    to 1x1, so the (N, 1) reshape assumes 64x64 inputs.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # (nc) x 64 x 64 -> (ndf) x 32 x 32; the input layer has no BatchNorm
        layers = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2)]
        # Each stride-2 block halves the spatial size and doubles the width:
        # 32 -> 16 -> 8 -> 4, ending at (ndf*8) x 4 x 4
        for mult_in, mult_out in ((1, 2), (2, 4), (4, 8)):
            layers += [
                nn.Conv2d(ndf * mult_in, ndf * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult_out),
                nn.LeakyReLU(0.2),
            ]
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1 score squashed by Sigmoid
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores.view(-1, 1)
# Build the generator, re-initialise its Conv/BatchNorm weights with
# weights_init (applied recursively to every submodule), and print its layers.
# NOTE: weights_init draws random numbers (normal_), so the order of these
# statements fixes the seeded initial weights; don't reorder them.
G = Generator()
G.apply(weights_init)
print(G)
# Same for the discriminator.
D = Discriminator()
D.apply(weights_init)
print(D)
Generator ( (model): Sequential ( (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (2): ReLU (inplace) (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (5): ReLU (inplace) (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (8): ReLU (inplace) (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) (11): ReLU (inplace) (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (13): Tanh () ) ) Discriminator ( (model): Sequential ( (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (1): LeakyReLU (0.2) (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (4): LeakyReLU (0.2) (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (7): LeakyReLU (0.2) (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (10): LeakyReLU (0.2) (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) (12): Sigmoid () ) )
# Binary cross-entropy on the discriminator's Sigmoid outputs.
criterion = nn.BCELoss()

# Put both networks and the loss on the GPU.
D.cuda()
G.cuda()
criterion.cuda()

# `noise` is a reusable GPU buffer for per-batch latent codes (contents are
# left uninitialised here and must be filled before use); `fixed_noise` is
# drawn once so the periodically displayed samples always use the same codes.
noise = torch.FloatTensor(batchSize, nz, 1, 1).cuda()
fixed_noise = Variable(torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1).cuda())

# setup optimizers: one Adam per network, with beta1 from the config
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))
# Create a square grid of axes for showing generated samples and hide the
# x/y axes of every cell (axes are visited in row-major order).
num_test_samples = 16
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for cell in ax.flat:
    cell.get_xaxis().set_visible(False)
    cell.get_yaxis().set_visible(False)
def display_samples(fake_images):
    """Draw the first ``num_test_samples`` generated images on the figure grid.

    fake_images: generator output indexed as fake_images[k] -> (C, H, W) with
    values in [-1, 1] (Tanh). Draws into the module-level ``ax`` grid, then
    replaces the notebook output with the refreshed figure.
    """
    for k in range(num_test_samples):
        # Derive the cell from the actual grid size instead of a hard-coded 4.
        row, col = divmod(k, size_figure_grid)
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1] for imshow.
        img = fake_images[k].data.cpu() / 2 + 0.5
        # imshow expects channels last: (C, H, W) -> (H, W, C).
        npimg = np.transpose(img.numpy(), (1, 2, 0))
        ax[row, col].cla()
        if npimg.shape[2] == 1:
            # Single-channel images must be 2-D; the colormap only applies here.
            ax[row, col].imshow(npimg[:, :, 0], cmap='gray')
        else:
            ax[row, col].imshow(npimg)
    display.clear_output(wait=True)
    display.display(plt.gcf())
display_every = 100


def _scalar(value):
    """Return a Python float from a one-element Variable/tensor or a number.

    PyTorch <= 0.3 exposes scalars as 1-element tensors read with ``[0]``;
    PyTorch >= 0.4 uses 0-dim tensors, where ``[0]`` raises and ``.item()``
    is required. This accepts either.
    """
    if isinstance(value, (int, float)):
        return float(value)
    data = value.data if hasattr(value, 'data') else value
    try:
        return float(data.item())
    except AttributeError:  # old tensors without .item()
        return float(data[0])


for epoch in range(nepochs):
    for i, data in enumerate(dataloader):
        real_cpu = data[0]  # images; the class labels are not used by this GAN
        # The last batch of an epoch can be smaller than batchSize, so size the
        # targets and the noise from the actual batch.
        batch_size = real_cpu.size(0)
        real = Variable(real_cpu.cuda())
        # D reshapes its output to (N, 1), so the BCE targets use that shape.
        real_target = Variable(torch.ones(batch_size, 1).cuda())
        fake_target = Variable(torch.zeros(batch_size, 1).cuda())

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        D.zero_grad()
        # train with real
        output = D(real)
        D_loss_real = criterion(output, real_target)
        D_loss_real.backward()
        D_x = _scalar(output.data.mean())

        # train with fake
        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        fake = G(Variable(noise))
        # detach() so the discriminator step does not backpropagate into G
        output = D(fake.detach())
        D_loss_fake = criterion(output, fake_target)
        D_loss_fake.backward()
        D_G_z1 = _scalar(output.data.mean())
        D_loss = D_loss_real + D_loss_fake
        D_optimizer.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ############################
        G.zero_grad()
        # Non-saturating generator loss: score the fakes against "real" targets.
        output = D(fake)
        G_loss = criterion(output, real_target)
        G_loss.backward()
        D_G_z2 = _scalar(output.data.mean())
        G_optimizer.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, nepochs, i, len(dataloader),
                 _scalar(D_loss), _scalar(G_loss), D_x, D_G_z1, D_G_z2))
        if i % display_every == 0:
            # DISPLAY GRID OF SAMPLES
            test_images = G(fixed_noise)
            display_samples(test_images)
[39/40][701/782] Loss_D: 0.0039 Loss_G: 6.7486 D(x): 1.0000 D(G(z)): 0.0039 / 0.0039 [39/40][702/782] Loss_D: 0.0008 Loss_G: 7.9971 D(x): 0.9999 D(G(z)): 0.0007 / 0.0007 [39/40][703/782] Loss_D: 0.0001 Loss_G: 10.1191 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][704/782] Loss_D: 0.0001 Loss_G: 9.8917 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][705/782] Loss_D: 0.0010 Loss_G: 7.4879 D(x): 0.9999 D(G(z)): 0.0009 / 0.0009 [39/40][706/782] Loss_D: 0.0011 Loss_G: 9.5507 D(x): 0.9991 D(G(z)): 0.0001 / 0.0001 [39/40][707/782] Loss_D: 0.0006 Loss_G: 9.1416 D(x): 0.9996 D(G(z)): 0.0001 / 0.0001 [39/40][708/782] Loss_D: 0.0043 Loss_G: 6.6727 D(x): 0.9984 D(G(z)): 0.0027 / 0.0027 [39/40][709/782] Loss_D: 0.0003 Loss_G: 9.9834 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001 [39/40][710/782] Loss_D: 0.0001 Loss_G: 10.2500 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001 [39/40][711/782] Loss_D: 0.0004 Loss_G: 8.1179 D(x): 1.0000 D(G(z)): 0.0004 / 0.0004 [39/40][712/782] Loss_D: 0.0009 Loss_G: 7.5218 D(x): 1.0000 D(G(z)): 0.0009 / 0.0009 [39/40][713/782] Loss_D: 0.0008 Loss_G: 7.9828 D(x): 0.9998 D(G(z)): 0.0006 / 0.0006 [39/40][714/782] Loss_D: 0.0038 Loss_G: 6.7694 D(x): 0.9999 D(G(z)): 0.0037 / 0.0037 [39/40][715/782] Loss_D: 0.0009 Loss_G: 10.4854 D(x): 0.9992 D(G(z)): 0.0001 / 0.0001 [39/40][716/782] Loss_D: 0.0004 Loss_G: 8.9052 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002 [39/40][717/782] Loss_D: 0.0004 Loss_G: 8.7619 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002 [39/40][718/782] Loss_D: 0.0001 Loss_G: 10.2157 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][719/782] Loss_D: 0.0022 Loss_G: 7.3180 D(x): 0.9988 D(G(z)): 0.0011 / 0.0011 [39/40][720/782] Loss_D: 0.0001 Loss_G: 9.8606 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][721/782] Loss_D: 0.0001 Loss_G: 10.0176 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][722/782] Loss_D: 0.0010 Loss_G: 7.4533 D(x): 1.0000 D(G(z)): 0.0010 / 0.0010 [39/40][723/782] Loss_D: 0.0003 Loss_G: 10.2137 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001 [39/40][724/782] Loss_D: 0.0013 
Loss_G: 7.2376 D(x): 0.9999 D(G(z)): 0.0012 / 0.0012 [39/40][725/782] Loss_D: 0.0015 Loss_G: 7.2382 D(x): 0.9999 D(G(z)): 0.0014 / 0.0014 [39/40][726/782] Loss_D: 0.0014 Loss_G: 11.0547 D(x): 0.9987 D(G(z)): 0.0000 / 0.0000 [39/40][727/782] Loss_D: 0.0014 Loss_G: 7.2081 D(x): 0.9999 D(G(z)): 0.0014 / 0.0014 [39/40][728/782] Loss_D: 0.0006 Loss_G: 10.0759 D(x): 0.9995 D(G(z)): 0.0001 / 0.0001 [39/40][729/782] Loss_D: 0.0002 Loss_G: 11.2808 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000 [39/40][730/782] Loss_D: 0.0007 Loss_G: 9.1643 D(x): 0.9995 D(G(z)): 0.0001 / 0.0001 [39/40][731/782] Loss_D: 0.0001 Loss_G: 9.3717 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][732/782] Loss_D: 0.0006 Loss_G: 8.7607 D(x): 0.9996 D(G(z)): 0.0002 / 0.0002 [39/40][733/782] Loss_D: 0.0005 Loss_G: 8.2456 D(x): 0.9999 D(G(z)): 0.0004 / 0.0004 [39/40][734/782] Loss_D: 0.0128 Loss_G: 8.4694 D(x): 0.9909 D(G(z)): 0.0002 / 0.0002 [39/40][735/782] Loss_D: 0.0003 Loss_G: 8.4101 D(x): 1.0000 D(G(z)): 0.0003 / 0.0003 [39/40][736/782] Loss_D: 0.0005 Loss_G: 8.0811 D(x): 0.9999 D(G(z)): 0.0004 / 0.0004 [39/40][737/782] Loss_D: 0.0007 Loss_G: 7.8918 D(x): 0.9999 D(G(z)): 0.0006 / 0.0006 [39/40][738/782] Loss_D: 0.0008 Loss_G: 7.8373 D(x): 0.9999 D(G(z)): 0.0006 / 0.0006 [39/40][739/782] Loss_D: 0.0015 Loss_G: 7.1616 D(x): 0.9999 D(G(z)): 0.0014 / 0.0014 [39/40][740/782] Loss_D: 0.0056 Loss_G: 6.7568 D(x): 0.9975 D(G(z)): 0.0031 / 0.0031 [39/40][741/782] Loss_D: 0.0004 Loss_G: 8.6771 D(x): 0.9999 D(G(z)): 0.0003 / 0.0003 [39/40][742/782] Loss_D: 0.0035 Loss_G: 6.9673 D(x): 0.9999 D(G(z)): 0.0033 / 0.0033 [39/40][743/782] Loss_D: 0.0009 Loss_G: 8.1543 D(x): 0.9997 D(G(z)): 0.0006 / 0.0006 [39/40][744/782] Loss_D: 0.0001 Loss_G: 9.8637 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][745/782] Loss_D: 0.0004 Loss_G: 8.7238 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002 [39/40][746/782] Loss_D: 0.0026 Loss_G: 6.9812 D(x): 0.9997 D(G(z)): 0.0023 / 0.0023 [39/40][747/782] Loss_D: 0.0006 Loss_G: 8.3039 D(x): 0.9999 D(G(z)): 
0.0004 / 0.0004 [39/40][748/782] Loss_D: 0.0004 Loss_G: 9.7751 D(x): 0.9996 D(G(z)): 0.0001 / 0.0001 [39/40][749/782] Loss_D: 0.0001 Loss_G: 10.0790 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][750/782] Loss_D: 0.0002 Loss_G: 8.7802 D(x): 1.0000 D(G(z)): 0.0002 / 0.0002 [39/40][751/782] Loss_D: 0.0001 Loss_G: 10.7864 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000 [39/40][752/782] Loss_D: 0.0003 Loss_G: 8.7146 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002 [39/40][753/782] Loss_D: 0.0003 Loss_G: 9.2446 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001 [39/40][754/782] Loss_D: 0.0002 Loss_G: 9.4009 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001 [39/40][755/782] Loss_D: 0.0005 Loss_G: 7.9675 D(x): 0.9999 D(G(z)): 0.0005 / 0.0005 [39/40][756/782] Loss_D: 0.0003 Loss_G: 8.7279 D(x): 0.9999 D(G(z)): 0.0002 / 0.0002 [39/40][757/782] Loss_D: 0.0001 Loss_G: 10.0214 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][758/782] Loss_D: 0.0008 Loss_G: 8.0148 D(x): 0.9997 D(G(z)): 0.0005 / 0.0005 [39/40][759/782] Loss_D: 0.0002 Loss_G: 9.3004 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][760/782] Loss_D: 0.0001 Loss_G: 9.6296 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][761/782] Loss_D: 0.0093 Loss_G: 6.5337 D(x): 0.9964 D(G(z)): 0.0054 / 0.0054 [39/40][762/782] Loss_D: 0.0001 Loss_G: 10.2362 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][763/782] Loss_D: 0.0004 Loss_G: 8.4533 D(x): 0.9999 D(G(z)): 0.0004 / 0.0004 [39/40][764/782] Loss_D: 0.0006 Loss_G: 7.9432 D(x): 1.0000 D(G(z)): 0.0006 / 0.0006 [39/40][765/782] Loss_D: 0.0000 Loss_G: 12.8699 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000 [39/40][766/782] Loss_D: 0.0010 Loss_G: 8.4724 D(x): 0.9993 D(G(z)): 0.0003 / 0.0003 [39/40][767/782] Loss_D: 0.0000 Loss_G: 13.6255 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000 [39/40][768/782] Loss_D: 0.0000 Loss_G: 12.4288 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000 [39/40][769/782] Loss_D: 0.0003 Loss_G: 8.3917 D(x): 1.0000 D(G(z)): 0.0003 / 0.0003 [39/40][770/782] Loss_D: 0.0002 Loss_G: 9.4292 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][771/782] 
Loss_D: 0.0012 Loss_G: 7.2875 D(x): 1.0000 D(G(z)): 0.0012 / 0.0012 [39/40][772/782] Loss_D: 0.0001 Loss_G: 10.0948 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][773/782] Loss_D: 0.0033 Loss_G: 6.9926 D(x): 0.9987 D(G(z)): 0.0020 / 0.0020 [39/40][774/782] Loss_D: 0.0001 Loss_G: 9.9110 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][775/782] Loss_D: 0.0002 Loss_G: 9.8876 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][776/782] Loss_D: 0.0001 Loss_G: 9.6672 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001 [39/40][777/782] Loss_D: 0.0007 Loss_G: 7.8837 D(x): 0.9999 D(G(z)): 0.0006 / 0.0006 [39/40][778/782] Loss_D: 0.0001 Loss_G: 10.5697 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000 [39/40][779/782] Loss_D: 0.0032 Loss_G: 6.8609 D(x): 0.9999 D(G(z)): 0.0031 / 0.0031 [39/40][780/782] Loss_D: 0.0002 Loss_G: 9.2162 D(x): 1.0000 D(G(z)): 0.0002 / 0.0002 [39/40][781/782] Loss_D: 0.0001 Loss_G: 9.4390 D(x): 1.0000 D(G(z)): 0.0001 / 0.0001