#!/usr/bin/env python
# coding: utf-8

# # Fitting a toy distribution with a GAN
#
# Implementations of the vanilla GAN [1], the least-squares GAN [2], the Wasserstein GAN [3] (gradient-penalty version [4]), and the hinge-loss GAN [5].
#
# [1] Goodfellow, Ian, et al. "Generative adversarial nets."
# Advances in Neural Information Processing Systems. 2014.
# [2] Mao, Xudong, et al. "Least squares generative adversarial networks."
# Proceedings of the IEEE International Conference on Computer Vision. 2017.
# [3] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein
# generative adversarial networks." Proceedings of the 34th International
# Conference on Machine Learning. 2017.
# [4] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs."
# Advances in Neural Information Processing Systems. 2017.
# [5] Lim, Jae Hyun, and Jong Chul Ye. "Geometric GAN."
# arXiv preprint arXiv:1705.02894 (2017).
# [6] Zhang, Han, et al. "Self-attention generative adversarial networks."
# International Conference on Machine Learning. PMLR, 2019.

# ## Setup notebook

# In[1]:

from typing import *
from functools import partial
from glob import glob
import math
import os
import random
import sys

gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('white')

import torch
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Report versions

# In[2]:

print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')

# In[3]:

pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))

# Check GPU(s)

# In[4]:

get_ipython().system('nvidia-smi | head -n 4')

# In[5]:

assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

# Set seeds for better reproducibility. See [this note](https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers) before using multiprocessing.

# In[6]:

seed = 9
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
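# The seeds above cover the single-process case used in this notebook. If the
# `DataLoader`s below were switched to `num_workers > 0`, each worker would also
# need its own seed, per the linked note. A minimal sketch of that recipe
# (hypothetical here, since the notebook keeps the default `num_workers=0`):

def seed_worker(worker_id: int):
    # derive per-worker NumPy/Python seeds from the torch seed assigned to this worker
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# usage: DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)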
# ## Create dataset

# In[7]:

loc = torch.tensor([0., 0.])
scale = torch.tensor([1., 1.])
base_dist = D.Normal(loc, scale)

# In[8]:

mix_comp = D.Bernoulli(torch.tensor([0.5]))
g1 = D.MultivariateNormal(torch.tensor([-5., -3.]), torch.tensor([[2., 0.5], [0.5, 3.]]))
g2 = D.MultivariateNormal(torch.tensor([ 4.,  4.]), torch.tensor([[1., -0.5], [-0.5, 2.]]))

def p_data_mix(n_samp: int):
    mix_comps = mix_comp.sample((n_samp,))
    mc0 = (mix_comps < 1)[:, 0]
    mc1 = (mix_comps > 0)[:, 0]
    samples = torch.zeros(n_samp, 2)
    samples[mc0, :] = g1.sample((mc0.sum(),))
    samples[mc1, :] = g2.sample((mc1.sum(),))
    return samples

def p_data_single(n_samp: int):
    return g2.sample((n_samp,))

sample_p_data = p_data_mix

# In[9]:

n_samples = 10000
x_real = sample_p_data(n_samples)
xr, yr = x_real[:, 0], x_real[:, 1]

# In[10]:

g = sns.jointplot(x=xr, y=yr, kind="hex");
g.ax_marg_x.set_title('True distribution');
plt.savefig('true_dist.svg')

# In[11]:

xlim = g.ax_joint.get_xlim()
ylim = g.ax_joint.get_ylim()

# ## Define a generator and a discriminator

# In[12]:

def G_layer_BN(in_c: int, out_c: int):
    return nn.Sequential(
        nn.Linear(in_c, out_c, bias=False),
        nn.BatchNorm1d(out_c),
        nn.LeakyReLU(inplace=True))

def G_layer_SN(in_c: int, out_c: int):
    return nn.Sequential(
        nn.utils.spectral_norm(nn.Linear(in_c, out_c)),
        nn.LeakyReLU(inplace=True))

class Generator(nn.Sequential):
    _layer = staticmethod(G_layer_SN)

    def __init__(self, in_dim: int, out_dim: int, n_layers: int = 5,
                 hidden_dim: int = 128, dropout_rate: float = 0.):
        super().__init__()
        self.add_module('h1', self._layer(in_dim, hidden_dim))
        for i in range(2, n_layers):
            self.add_module(f'h{i}', self._layer(hidden_dim, hidden_dim))
        if dropout_rate > 0.:
            self.add_module('dropout', nn.Dropout(dropout_rate))
        self.add_module(f'h{n_layers}', nn.Linear(hidden_dim, out_dim))

def D_layer_BN(in_c: int, out_c: int):
    return nn.Sequential(
        nn.Linear(in_c, out_c, bias=False),
        nn.BatchNorm1d(out_c),
        nn.LeakyReLU(0.2, inplace=True))

def D_layer_SN(in_c: int, out_c: int):
    return nn.Sequential(
        nn.utils.spectral_norm(nn.Linear(in_c, out_c)),
        nn.LeakyReLU(0.1, inplace=True))

class Discriminator(Generator):
    _layer = staticmethod(D_layer_SN)

# In[13]:

hidden_dim = 128
n_layers = 4
dropout_rate = 0.

g_args = (2, 2, n_layers, hidden_dim, dropout_rate)
d_args = (2, 1, n_layers, hidden_dim, dropout_rate)

generator = Generator(*g_args).to(device)
discriminator = Discriminator(*d_args).to(device)

# In[14]:

D_final_activation = None
if D_final_activation == 'tanh':
    discriminator.add_module('tanh', nn.Tanh())
elif D_final_activation == 'sigmoid':
    discriminator.add_module('sigmoid', nn.Sigmoid())

# In[15]:

def num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# In[16]:

print(f'Number of trainable parameters in generator: {num_params(generator)}')
print(f'Number of trainable parameters in discriminator: {num_params(discriminator)}')

# In[17]:

def weights_init(m):
    name = m.__class__.__name__
    if 'Linear' in name or 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 0., 0.02)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)

generator.apply(weights_init);
discriminator.apply(weights_init);

# In[18]:

def gradient_penalty(y, x):
    weight = torch.ones_like(y)
    grad = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight,
                               retain_graph=True, create_graph=True,
                               only_inputs=True)[0]
    return torch.mean((grad.norm(dim=1) - 1) ** 2)
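# A quick, hedged sanity check of `gradient_penalty` (an illustrative addition,
# not one of the original cells): for a linear critic $D(x) = x \cdot w$ the
# input gradient is $w$ for every sample, so the penalty should equal
# $(\lVert w \rVert_2 - 1)^2$.

w_check = torch.tensor([3., 4.])                     # ||w|| = 5
x_check = torch.randn(8, 2, requires_grad=True)
print(gradient_penalty(x_check @ w_check, x_check))  # expect (5 - 1)^2 = 16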
# ## Train the generator and discriminator

# In[19]:

n_epochs = 5000
n_samples_per_epoch = n_samples
print_loss_rate = 500
batch_size = n_samples_per_epoch
use_minibatches = batch_size != n_samples_per_epoch

# In[20]:

x_real = x_real.to(device)

# In[21]:

gan_type = 'hinge'
betas = (0., 0.999)  # parameters from the self-attention GAN [6]
G_lr = 1e-4
D_lr = 4e-4
D_steps = 1
use_gp = False
gp_weight = 10.

G_opt = torch.optim.Adam(generator.parameters(), lr=G_lr, betas=betas)
D_opt = torch.optim.Adam(discriminator.parameters(), lr=D_lr, betas=betas)

# In[22]:

real_label = 1.
fake_label = 0.
real_labels = real_label * torch.ones((n_samples_per_epoch, 1))
fake_labels = fake_label * torch.ones((n_samples_per_epoch, 1))
real_labels = real_labels.to(device)
fake_labels = fake_labels.to(device)

# In[23]:

if use_minibatches:
    MyDataLoader = partial(DataLoader, batch_size=batch_size, shuffle=True)
    x_real_dataset = TensorDataset(x_real)
    x_real_dataloader = MyDataLoader(x_real_dataset)

# In[24]:

def reset_grad():
    D_opt.zero_grad()
    G_opt.zero_grad()
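# For reference, the objectives minimized in the next cell, written to match the
# code (the two discriminator terms are averaged for the non-Wasserstein losses;
# the hinge form follows [5] and the gradient penalty follows [4]):
#
# * vanilla: $L_D = \tfrac{1}{2}\,[\mathrm{BCE}(D(x), 1) + \mathrm{BCE}(D(G(z)), 0)]$, $\;L_G = \mathrm{BCE}(D(G(z)), 1)$ (BCE on logits)
# * lsgan: $L_D = \tfrac{1}{2}\,[(D(x) - 1)^2 + D(G(z))^2]$, $\;L_G = \tfrac{1}{2}\,(D(G(z)) - 1)^2$
# * wgan-gp: $L_D = D(G(z)) - D(x) + \lambda\,(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2$, $\;L_G = -D(G(z))$
# * hinge: $L_D = \tfrac{1}{2}\,[\max(0, 1 - D(x)) + \max(0, 1 + D(G(z)))]$, $\;L_G = -D(G(z))$
#
# with all terms averaged over the batch and $\hat{x} = \epsilon x + (1 - \epsilon)\,G(z)$, $\epsilon \sim U(0, 1)$.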
# In[25]:

def train_discriminator(x_real, z):
    for _ in range(D_steps):
        reset_grad()

        # discriminate real samples
        D_real = discriminator(x_real)
        if gan_type == 'vanilla':
            loss_real = F.binary_cross_entropy_with_logits(D_real, real_labels)
            D_x = torch.sigmoid(D_real).mean().item()
        elif gan_type == 'lsgan':
            loss_real = torch.mean((D_real - real_label)**2)
            D_x = D_real.mean().item()
        elif gan_type == 'wgan-gp':
            loss_real = D_real.mean()
            D_x = loss_real.item()
        elif gan_type == 'hinge':
            loss_real = F.relu(1. - D_real).mean()
            D_x = loss_real.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')

        # discriminate fake samples
        with torch.no_grad():
            x_fake = generator(z)
        D_fake = discriminator(x_fake)
        if gan_type == 'vanilla':
            loss_fake = F.binary_cross_entropy_with_logits(D_fake, fake_labels)
            D_G_z_1 = torch.sigmoid(D_fake).mean().item()
        elif gan_type == 'lsgan':
            loss_fake = torch.mean((D_fake - fake_label)**2)
            D_G_z_1 = D_fake.mean().item()
        elif gan_type == 'wgan-gp':
            loss_fake = D_fake.mean()
            D_G_z_1 = loss_fake.item()
        elif gan_type == 'hinge':
            loss_fake = F.relu(1. + D_fake).mean()
            D_G_z_1 = loss_fake.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')

        # gradient penalty on interpolates between real and fake samples
        if use_gp or gan_type == 'wgan-gp':
            eps = torch.rand(batch_size, 1).to(device)
            x_hat = (eps * x_real + (1. - eps) * x_fake)
            x_hat.requires_grad_(True)
            D_x_hat = discriminator(x_hat)
            gp = gradient_penalty(D_x_hat, x_hat)

        if gan_type != 'wgan-gp':
            D_loss = 0.5 * (loss_fake + loss_real)
            if use_gp:
                D_loss += gp_weight * gp
        else:
            D_loss = loss_fake - loss_real + gp_weight * gp

        D_loss.backward()
        D_opt.step()

    return D_loss.item(), D_x, D_G_z_1

def train_generator(z):
    reset_grad()
    x_fake = generator(z)
    D_fake = discriminator(x_fake)
    if gan_type == 'vanilla':
        G_loss = F.binary_cross_entropy_with_logits(D_fake, real_labels)
        D_G_z_2 = torch.sigmoid(D_fake).mean().item()
    elif gan_type == 'lsgan':
        # push D(G(z)) toward the real label; the original `torch.mean(D_fake**2)`
        # pushed it toward the fake label, i.e. the wrong direction for the generator
        G_loss = 0.5 * torch.mean((D_fake - real_label)**2)
        D_G_z_2 = D_fake.mean().item()
    elif gan_type == 'wgan-gp':
        G_loss = -D_fake.mean()
        D_G_z_2 = G_loss.item()
    elif gan_type == 'hinge':
        G_loss = -D_fake.mean()
        D_G_z_2 = G_loss.item()
    else:
        raise NotImplementedError(f'{gan_type} not implemented.')
    G_loss.backward()
    G_opt.step()
    return G_loss.item(), D_G_z_2

# In[26]:

for i in range(1, n_epochs + 1):
    z_full = base_dist.sample((n_samples_per_epoch,)).to(device)
    if use_minibatches:
        z_dataset = TensorDataset(z_full)
        z_loader = MyDataLoader(z_dataset)

    # train discriminator
    generator.eval(); discriminator.train();
    if use_minibatches:
        for x_r, z in zip(x_real_dataloader, z_loader):
            x_r, z = x_r[0], z[0]
            D_loss, D_x, D_G_z_1 = train_discriminator(x_r, z)
    else:
        D_loss, D_x, D_G_z_1 = train_discriminator(x_real, z_full)

    # train generator
    generator.train(); discriminator.eval();
    if use_minibatches:
        for z in z_loader:
            z = z[0]
            G_loss, D_G_z_2 = train_generator(z)
    else:
        G_loss, D_G_z_2 = train_generator(z_full)

    if i % print_loss_rate == 0:
        print(f'Epoch {i}: D_loss={D_loss:0.4f}, G_loss={G_loss:0.4f}, '
              f'D(x): {D_x:0.3f}, D(G(z)): {D_G_z_1:0.3f}/{D_G_z_2:0.3f}')

# In[27]:

z = base_dist.sample((n_samples_per_epoch,)).to(device)
generator.eval()
with torch.no_grad():
    x_fake = generator(z).detach().cpu().numpy()
xf, yf = x_fake[:, 0], x_fake[:, 1]

# In[28]:

sns.jointplot(x=xr, y=yr, kind="hex");
g = sns.jointplot(x=xf, y=yf, kind="hex", xlim=xlim, ylim=ylim);
g.ax_marg_x.set_title('Fit distribution');
plt.savefig('fit_dist.svg')
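# As an optional numeric check to go with the plots (an addition, not one of the
# original cells), compare the first two moments of the real and generated
# samples; for a reasonable fit they should be close. In a two-component mixture
# the overall mean/covariance summarize both modes, so this is only a coarse
# check alongside the hex-bin plots above.

x_real_np = x_real.cpu().numpy()
print('real mean:', x_real_np.mean(axis=0), ' fake mean:', x_fake.mean(axis=0))
print('real cov:\n', np.cov(x_real_np, rowvar=False))
print('fake cov:\n', np.cov(x_fake, rowvar=False))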