Implementations of the vanilla GAN [1], least-squares GAN [2], Wasserstein GAN [3] (GP version [4]), and the Hinge-loss GAN [5]
[1] Goodfellow, Ian, et al. "Generative adversarial nets."
Advances in neural information processing systems. 2014.
[2] Mao, Xudong, et al. "Least squares generative adversarial networks."
Proceedings of the IEEE international conference on computer vision. 2017.
[3] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein
generative adversarial networks." Proceedings of the 34th International
Conference on Machine Learning-Volume 70. 2017.
[4] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs."
Advances in neural information processing systems. 2017.
[5] Lim, Jae Hyun, and Jong Chul Ye. "Geometric GAN."
arXiv preprint arXiv:1705.02894 (2017).
[6] Zhang, Han, et al. "Self-attention generative adversarial networks."
International Conference on Machine Learning. PMLR, 2019.
# Standard-library imports.
from typing import *
from functools import partial
from glob import glob
import math
import os
import random
import sys
# Restrict CUDA to a single device; must be set BEFORE torch initializes CUDA.
gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'
# Third-party imports (plotting + numerics + PyTorch).
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('white')
import torch
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
Report versions
# Report versions of the main numeric/plotting/DL libraries for reproducibility.
from matplotlib import __version__ as mplver
print('numpy version: {}'.format(np.__version__))
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')
numpy version: 1.19.1 matplotlib version: 3.3.1 pytorch version: 1.6.0
# Report the interpreter version (major.minor.micro).
pv = sys.version_info
print('python version: {}.{}.{}'.format(*pv[:3]))
python version: 3.8.5
Check GPU(s)
!nvidia-smi | head -n 4
Sun Sep 20 14:02:12 2020 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 430.40 Driver Version: 430.40 CUDA Version: 10.1 | |-------------------------------+----------------------+----------------------+
# Fail fast if no GPU is visible; all tensors below are moved to `device`.
assert torch.cuda.is_available()
device = torch.device('cuda')
# Let cuDNN autotune kernels; beneficial since input shapes are fixed here.
torch.backends.cudnn.benchmark = True
Set seeds for better reproducibility. Note: additional care (e.g. per-worker seeding) is required for reproducibility when using multiprocessing, such as DataLoader workers.
# Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) so runs repeat.
seed = 9
for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seeder(seed)
# Latent prior: 2-D standard normal (each coordinate sampled independently).
loc = torch.tensor([0., 0.])
scale = torch.tensor([1., 1.])
base_dist = D.Normal(loc, scale)

# Fair coin deciding which mixture component each data point comes from.
mix_comp = D.Bernoulli(torch.tensor([0.5]))

# The two Gaussian components of the target ("real") mixture.
g1 = D.MultivariateNormal(torch.tensor([-5., -3.]),
                          torch.tensor([[2., 0.5], [0.5, 3.]]))
g2 = D.MultivariateNormal(torch.tensor([4., 4.]),
                          torch.tensor([[1., -0.5], [-0.5, 2.]]))


def p_data_mix(n_samp: int):
    """Draw `n_samp` points from the two-component Gaussian mixture."""
    picks = mix_comp.sample((n_samp,))[:, 0]
    from_g1 = picks < 1
    from_g2 = ~from_g1
    out = torch.zeros(n_samp, 2)
    out[from_g1, :] = g1.sample((int(from_g1.sum()),))
    out[from_g2, :] = g2.sample((int(from_g2.sum()),))
    return out


def p_data_single(n_samp: int):
    """Draw `n_samp` points from the single Gaussian component g2."""
    return g2.sample((n_samp,))


# The data distribution the GAN is trained to fit.
sample_p_data = p_data_mix
# Draw a reference sample from the target distribution and plot it as a
# hex-binned joint plot; keep the axis limits so the fitted plot can reuse them.
n_samples = 10000
x_real = sample_p_data(n_samples)
xr, yr = x_real[:, 0], x_real[:, 1]
g = sns.jointplot(x=xr, y=yr, kind="hex")
g.ax_marg_x.set_title('True distribution')
plt.savefig('true_dist.svg')
xlim, ylim = g.ax_joint.get_xlim(), g.ax_joint.get_ylim()
def G_layer_BN(in_c: int, out_c: int):
    """Generator hidden layer: bias-free linear -> batch norm -> leaky ReLU."""
    stages = [
        nn.Linear(in_c, out_c, bias=False),  # BN supplies the shift, so no bias
        nn.BatchNorm1d(out_c),
        nn.LeakyReLU(inplace=True),
    ]
    return nn.Sequential(*stages)
def G_layer_SN(in_c: int, out_c: int):
    """Generator hidden layer: spectrally-normalized linear -> leaky ReLU."""
    linear = nn.utils.spectral_norm(nn.Linear(in_c, out_c))
    return nn.Sequential(linear, nn.LeakyReLU(inplace=True))
class Generator(nn.Sequential):
    """MLP generator: (n_layers - 1) hidden layers plus a linear output layer.

    The hidden-layer factory is pluggable through the `_layer` class attribute
    (spectral-norm layers by default), so subclasses can swap it.
    """
    _layer = staticmethod(G_layer_SN)

    def __init__(self, in_dim: int, out_dim: int, n_layers: int = 5,
                 hidden_dim: int = 128, dropout_rate: float = 0.):
        super().__init__()
        self.add_module('h1', self._layer(in_dim, hidden_dim))
        for idx in range(2, n_layers):
            self.add_module(f'h{idx}', self._layer(hidden_dim, hidden_dim))
        # Optional dropout inserted just before the output layer.
        if dropout_rate > 0.:
            self.add_module('dropout', nn.Dropout(dropout_rate))
        self.add_module(f'h{n_layers}', nn.Linear(hidden_dim, out_dim))
def D_layer_BN(in_c: int, out_c: int):
    """Discriminator hidden layer: bias-free linear -> batch norm -> LeakyReLU(0.2)."""
    stages = [
        nn.Linear(in_c, out_c, bias=False),  # BN supplies the shift, so no bias
        nn.BatchNorm1d(out_c),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*stages)
def D_layer_SN(in_c: int, out_c: int):
    """Discriminator hidden layer: spectrally-normalized linear -> LeakyReLU(0.1)."""
    linear = nn.utils.spectral_norm(nn.Linear(in_c, out_c))
    return nn.Sequential(linear, nn.LeakyReLU(0.1, inplace=True))
class Discriminator(Generator):
    """Same MLP topology as the Generator, but built from D's hidden-layer factory."""
    _layer = staticmethod(D_layer_SN)


# Network hyperparameters shared by G and D.
hidden_dim = 128
n_layers = 4
dropout_rate = 0.
g_args = (2, 2, n_layers, hidden_dim, dropout_rate)  # 2-D noise -> 2-D sample
d_args = (2, 1, n_layers, hidden_dim, dropout_rate)  # 2-D sample -> 1 score/logit
generator = Generator(*g_args).to(device)
discriminator = Discriminator(*d_args).to(device)

# Optional squashing of the discriminator output (None keeps raw logits/scores).
D_final_activation = None
if D_final_activation == 'tanh':
    discriminator.add_module('tanh', nn.Tanh())
elif D_final_activation == 'sigmoid':
    discriminator.add_module('sigmoid', nn.Sigmoid())
def num_params(model):
    """Count the trainable (requires_grad) parameters of `model`."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
# Report model sizes.
for tag, model in (('generator', generator), ('discriminator', discriminator)):
    print(f'Number of trainable parameters in {tag}: {num_params(model)}')
Number of trainable parameters in generator: 33666 Number of trainable parameters in discriminator: 33537
def weights_init(m):
    """DCGAN-style weight initialization (Radford et al., 2016).

    Linear weights ~ N(0, 0.02); BatchNorm scales ~ N(1, 0.02) so that
    normalization layers start near the identity; all biases are zeroed.
    Modules of other types are left untouched.
    """
    name = m.__class__.__name__
    if 'BatchNorm' in name:
        # BUGFIX: BN gamma must be centered at 1, not 0 — a ~zero scale
        # would kill activations at the start of training.
        nn.init.normal_(m.weight.data, 1., 0.02)
    elif 'Linear' in name:
        nn.init.normal_(m.weight.data, 0., 0.02)
    else:
        return
    if getattr(m, 'bias', None) is not None:
        nn.init.constant_(m.bias.data, 0.)
# Apply the custom initialization to both networks.
for net in (generator, discriminator):
    net.apply(weights_init)
def gradient_penalty(y, x):
    """WGAN-GP penalty [4]: mean squared deviation of ||dy/dx|| from 1.

    `x` must require grad and `y` must be computed from it; the gradient graph
    is kept (create_graph=True) so the penalty itself is differentiable.
    """
    grad_out = torch.ones_like(y)
    (grads,) = torch.autograd.grad(
        outputs=y, inputs=x, grad_outputs=grad_out,
        retain_graph=True, create_graph=True, only_inputs=True)
    row_norms = grads.norm(dim=1)
    return ((row_norms - 1.) ** 2).mean()
# ----- Training configuration -----
n_epochs = 5000
n_samples_per_epoch = n_samples
print_loss_rate = 500
batch_size = n_samples_per_epoch  # full-batch training by default
use_minibatches = batch_size != n_samples_per_epoch
x_real = x_real.to(device)

gan_type = 'hinge'  # one of: 'vanilla' | 'lsgan' | 'wgan-gp' | 'hinge'
betas = (0., 0.999)  # parameters from self-attention GAN [6]
G_lr, D_lr = 1e-4, 4e-4  # two-timescale learning rates (D faster than G)
D_steps = 1
use_gp = False  # optionally add the gradient penalty to non-WGAN losses
gp_weight = 10.

G_opt = torch.optim.Adam(generator.parameters(), lr=G_lr, betas=betas)
D_opt = torch.optim.Adam(discriminator.parameters(), lr=D_lr, betas=betas)

# Target labels for the classification-style (vanilla/lsgan) losses.
real_label, fake_label = 1., 0.
real_labels = (real_label * torch.ones((n_samples_per_epoch, 1))).to(device)
fake_labels = (fake_label * torch.ones((n_samples_per_epoch, 1))).to(device)

if use_minibatches:
    MyDataLoader = partial(DataLoader, batch_size=batch_size, shuffle=True)
    x_real_dataset = TensorDataset(x_real)
    x_real_dataloader = MyDataLoader(x_real_dataset)
def reset_grad():
    """Zero the gradients held by both optimizers."""
    for opt in (D_opt, G_opt):
        opt.zero_grad()
def train_discriminator(x_real, z):
    """Run `D_steps` discriminator updates on one batch of real data + noise.

    Args:
        x_real: batch of real samples, shape (B, 2).
        z: latent noise for the generator, one row per fake sample.

    Returns:
        (D_loss, D_x, D_G_z_1) from the last step: the discriminator loss and
        per-branch monitoring values (mean sigmoid probabilities for 'vanilla';
        mean scores/losses for the other objectives).
    """
    for _ in range(D_steps):
        reset_grad()
        # --- real branch ---
        D_real = discriminator(x_real)
        if gan_type == 'vanilla':
            # NOTE(review): real_labels/fake_labels are sized for the full
            # epoch; a smaller final minibatch would mismatch — confirm before
            # enabling minibatches with the 'vanilla' loss.
            loss_real = F.binary_cross_entropy_with_logits(D_real, real_labels)
            D_x = torch.sigmoid(D_real).mean().item()
        elif gan_type == 'lsgan':
            loss_real = torch.mean((D_real - real_label)**2)
            D_x = D_real.mean().item()
        elif gan_type == 'wgan-gp':
            loss_real = D_real.mean()
            D_x = loss_real.item()
        elif gan_type == 'hinge':
            loss_real = F.relu(1. - D_real).mean()
            D_x = loss_real.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')
        # --- fake branch (generator frozen: no grad through G here) ---
        with torch.no_grad():
            x_fake = generator(z)
        D_fake = discriminator(x_fake)
        if gan_type == 'vanilla':
            loss_fake = F.binary_cross_entropy_with_logits(D_fake, fake_labels)
            D_G_z_1 = torch.sigmoid(D_fake).mean().item()
        elif gan_type == 'lsgan':
            loss_fake = torch.mean((D_fake - fake_label)**2)
            D_G_z_1 = D_fake.mean().item()
        elif gan_type == 'wgan-gp':
            loss_fake = D_fake.mean()
            D_G_z_1 = loss_fake.item()
        elif gan_type == 'hinge':
            loss_fake = F.relu(1. + D_fake).mean()
            D_G_z_1 = loss_fake.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')
        # --- gradient penalty on random interpolates [4] ---
        if use_gp or gan_type == 'wgan-gp':
            # BUGFIX: size eps by the actual batch (not the global batch_size)
            # so a smaller final minibatch cannot cause a shape mismatch.
            eps = torch.rand(x_real.size(0), 1, device=x_real.device)
            x_hat = (eps*x_real + (1.-eps)*x_fake)
            x_hat.requires_grad_(True)
            D_x_hat = discriminator(x_hat)
            gp = gradient_penalty(D_x_hat, x_hat)
        if gan_type != 'wgan-gp':
            D_loss = 0.5 * (loss_fake + loss_real)
            if use_gp:
                D_loss += gp_weight * gp
        else:
            # WGAN critic: maximize E[D(real)] - E[D(fake)], i.e. minimize the
            # negated difference, plus the gradient penalty.
            D_loss = loss_fake - loss_real + gp_weight * gp
        D_loss.backward()
        D_opt.step()
    return D_loss.item(), D_x, D_G_z_1
def train_generator(z):
    """One generator update: push D's output on fakes toward the 'real' target.

    Returns (G_loss, D_G_z_2), where the second value is a monitoring statistic
    (mean sigmoid probability for 'vanilla', mean score/loss otherwise).
    """
    reset_grad()
    D_fake = discriminator(generator(z))
    if gan_type == 'vanilla':
        # Non-saturating loss: maximize log D(G(z)) via BCE against real labels.
        G_loss = F.binary_cross_entropy_with_logits(D_fake, real_labels)
        D_G_z_2 = torch.sigmoid(D_fake).mean().item()
    elif gan_type == 'lsgan':
        G_loss = 0.5 * torch.mean(D_fake**2)
        D_G_z_2 = D_fake.mean().item()
    elif gan_type in ('wgan-gp', 'hinge'):
        # Both objectives use the same generator loss: -E[D(G(z))].
        G_loss = -D_fake.mean()
        D_G_z_2 = G_loss.item()
    else:
        raise NotImplementedError(f'{gan_type} not implemented.')
    G_loss.backward()
    G_opt.step()
    return G_loss.item(), D_G_z_2
# Main training loop: per "epoch", draw fresh latent noise, then do one
# discriminator phase followed by one generator phase (full-batch unless
# use_minibatches is set).
for i in range(1, n_epochs+1):
    # Fresh latent sample for this epoch.
    z_full = base_dist.sample((n_samples_per_epoch,)).to(device)
    if use_minibatches:
        z_dataset = TensorDataset(z_full)
        z_loader = MyDataLoader(z_dataset)
    # train discriminator
    generator.eval();
    discriminator.train();
    if use_minibatches:
        for x_r, z in zip(x_real_dataloader, z_loader):
            # TensorDataset batches arrive as 1-tuples; unwrap the tensors.
            x_r, z = x_r[0], z[0]
            D_loss, D_x, D_G_z_1 = train_discriminator(x_r, z)
    else:
        D_loss, D_x, D_G_z_1 = train_discriminator(x_real, z_full)
    # train generator
    generator.train();
    discriminator.eval();
    if use_minibatches:
        for z in z_loader:
            z = z[0]
            G_loss, D_G_z_2 = train_generator(z)
    else:
        G_loss, D_G_z_2 = train_generator(z_full)
    # Periodic progress report (statistics from the most recent batch only).
    if i % print_loss_rate == 0:
        print(f'Epoch {i}: D_loss={D_loss:0.4f}, G_loss={G_loss:0.4f}, '
              f'D(x): {D_x:0.3f}, D(G(z)): {D_G_z_1:0.3f}/{D_G_z_2:0.3f}')
Epoch 500: D_loss=0.3533, G_loss=1.0013, D(x): 0.528, D(G(z)): 0.179/1.001 Epoch 1000: D_loss=0.4411, G_loss=0.9532, D(x): 0.762, D(G(z)): 0.120/0.953 Epoch 1500: D_loss=0.4974, G_loss=0.9421, D(x): 0.969, D(G(z)): 0.026/0.942 Epoch 2000: D_loss=0.4967, G_loss=0.9665, D(x): 0.989, D(G(z)): 0.004/0.967 Epoch 2500: D_loss=0.4981, G_loss=0.9756, D(x): 0.995, D(G(z)): 0.002/0.976 Epoch 3000: D_loss=0.4981, G_loss=0.9996, D(x): 0.991, D(G(z)): 0.006/1.000 Epoch 3500: D_loss=0.4980, G_loss=0.9862, D(x): 0.995, D(G(z)): 0.001/0.986 Epoch 4000: D_loss=0.4985, G_loss=0.9906, D(x): 0.996, D(G(z)): 0.001/0.991 Epoch 4500: D_loss=0.4981, G_loss=1.0054, D(x): 0.990, D(G(z)): 0.006/1.005 Epoch 5000: D_loss=0.4988, G_loss=1.0052, D(x): 0.990, D(G(z)): 0.007/1.005
# Sample the trained generator and plot the fitted distribution on the same
# axis limits as the true one, for a side-by-side visual comparison.
z = base_dist.sample((n_samples_per_epoch,)).to(device)
generator.eval()
with torch.no_grad():
    x_fake = generator(z).detach().cpu().numpy()
xf, yf = x_fake[:, 0], x_fake[:, 1]
sns.jointplot(x=xr, y=yr, kind="hex")
g = sns.jointplot(x=xf, y=yf, kind="hex", xlim=xlim, ylim=ylim)
g.ax_marg_x.set_title('Fit distribution')
plt.savefig('fit_dist.svg')