#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:

get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")

# - Runs on CPU or GPU (if available)

# # Deep Convolutional Wasserstein GAN

# Implementation of a deep convolutional Wasserstein GAN based on the paper
#
# - Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. arXiv preprint arXiv:1701.07875. (https://arxiv.org/abs/1701.07875)
#
# The main differences to a conventional deep convolutional GAN are annotated in the code. In short, the main differences are
#
# 1. Not using a sigmoid activation function and just using a linear output layer for the critic (i.e., discriminator).
# 2. Using label -1 instead of 1 for the real images; using label 1 instead of 0 for fake images.
# 3. Using Wasserstein distance (loss) for training both the critic and the generator.
# 4. After each weight update, clip the weights to be in range [-0.01, 0.01].
# 5. Train the critic 5 times for each generator training update.
# # ## Imports # In[2]: import time import numpy as np import torch import torch.nn.functional as F from torchvision import datasets from torchvision import transforms import torch.nn as nn from torch.utils.data import DataLoader if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True # ## Settings and Dataset # In[3]: ########################## ### SETTINGS ########################## # Device device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Hyperparameters random_seed = 0 generator_learning_rate = 0.00005 discriminator_learning_rate = 0.00005 NUM_EPOCHS = 100 BATCH_SIZE = 128 LATENT_DIM = 100 IMG_SHAPE = (1, 28, 28) IMG_SIZE = 1 for x in IMG_SHAPE: IMG_SIZE *= x ## WGAN-specific settings num_iter_critic = 5 weight_clip_value = 0.01 ########################## ### MNIST DATASET ########################## # Note transforms.ToTensor() scales input images # to 0-1 range train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True) test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor()) train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True) test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=False) # Checking the dataset for images, labels in train_loader: print('Image batch dimensions:', images.shape) print('Image label dimensions:', labels.shape) break # ## Model # In[4]: ########################## ### MODEL ########################## class Flatten(nn.Module): def forward(self, input): return input.view(input.size(0), -1) class Reshape1(nn.Module): def forward(self, input): return input.view(input.size(0), 64, 7, 7) def wasserstein_loss(y_true, y_pred): return torch.mean(y_true * y_pred) class GAN(torch.nn.Module): def __init__(self): super(GAN, self).__init__() self.generator = nn.Sequential( nn.Linear(LATENT_DIM, 3136, bias=False), nn.BatchNorm1d(num_features=3136), 
nn.LeakyReLU(inplace=True, negative_slope=0.0001), Reshape1(), nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False), nn.BatchNorm2d(num_features=32), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False), nn.BatchNorm2d(num_features=16), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=0, bias=False), nn.BatchNorm2d(num_features=8), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=(2, 2), stride=(1, 1), padding=0, bias=False), nn.Tanh() ) self.discriminator = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=8, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False), nn.BatchNorm2d(num_features=8), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), nn.Conv2d(in_channels=8, out_channels=16, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False), nn.BatchNorm2d(num_features=16), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), nn.Conv2d(in_channels=16, out_channels=32, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False), nn.BatchNorm2d(num_features=32), nn.LeakyReLU(inplace=True, negative_slope=0.0001), #nn.Dropout2d(p=0.2), Flatten(), nn.Linear(512, 1), #nn.Sigmoid() ) def generator_forward(self, z): img = self.generator(z) return img def discriminator_forward(self, img): pred = model.discriminator(img) return pred.view(-1) # In[5]: torch.manual_seed(random_seed) #del model model = GAN() model = model.to(device) print(model) # In[6]: ### ## FOR DEBUGGING """ outputs = [] def hook(module, input, output): outputs.append(output) for i, layer in enumerate(model.discriminator): if isinstance(layer, torch.nn.modules.conv.Conv2d): 
model.discriminator[i].register_forward_hook(hook) #for i, layer in enumerate(model.generator): # if isinstance(layer, torch.nn.modules.ConvTranspose2d): # model.generator[i].register_forward_hook(hook) """ # In[7]: optim_gener = torch.optim.RMSprop(model.generator.parameters(), lr=generator_learning_rate) optim_discr = torch.optim.RMSprop(model.discriminator.parameters(), lr=discriminator_learning_rate) # ## Training # In[8]: start_time = time.time() discr_costs = [] gener_costs = [] for epoch in range(NUM_EPOCHS): model = model.train() for batch_idx, (features, targets) in enumerate(train_loader): # Normalize images to [-1, 1] range features = (features - 0.5)*2. features = features.view(-1, IMG_SIZE).to(device) targets = targets.to(device) # Regular GAN: # valid = torch.ones(targets.size(0)).float().to(device) # fake = torch.zeros(targets.size(0)).float().to(device) # WGAN: valid = -(torch.ones(targets.size(0)).float()).to(device) fake = torch.ones(targets.size(0)).float().to(device) ### FORWARD AND BACK PROP # -------------------------- # Train Generator # -------------------------- # Make new images z = torch.zeros((targets.size(0), LATENT_DIM)).uniform_(0.0, 1.0).to(device) generated_features = model.generator_forward(z) # Loss for fooling the discriminator discr_pred = model.discriminator_forward(generated_features.view(targets.size(0), 1, 28, 28)) # Regular GAN: # gener_loss = F.binary_cross_entropy_with_logits(discr_pred, valid) # WGAN: gener_loss = wasserstein_loss(valid, discr_pred) optim_gener.zero_grad() gener_loss.backward() optim_gener.step() # -------------------------- # Train Discriminator # -------------------------- # WGAN: Multiple loops for the discriminator for _ in range(num_iter_critic): discr_pred_real = model.discriminator_forward(features.view(targets.size(0), 1, 28, 28)) # Regular GAN: # real_loss = F.binary_cross_entropy_with_logits(discr_pred_real, valid) # WGAN: real_loss = wasserstein_loss(valid, discr_pred_real) discr_pred_fake = 
model.discriminator_forward(generated_features.view(targets.size(0), 1, 28, 28).detach()) # Regular GAN: # fake_loss = F.binary_cross_entropy_with_logits(discr_pred_fake, fake) # WGAN: fake_loss = wasserstein_loss(fake, discr_pred_fake) discr_loss = 0.5*(real_loss + fake_loss) optim_discr.zero_grad() discr_loss.backward() optim_discr.step() # WGAN: for p in model.discriminator.parameters(): p.data.clamp_(-weight_clip_value, weight_clip_value) discr_costs.append(discr_loss.item()) gener_costs.append(gener_loss.item()) ### LOGGING if not batch_idx % 100: print ('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f' %(epoch+1, NUM_EPOCHS, batch_idx, len(train_loader), gener_loss, discr_loss)) print('Time elapsed: %.2f min' % ((time.time() - start_time)/60)) print('Total Training Time: %.2f min' % ((time.time() - start_time)/60)) # In[9]: ### For Debugging """ for i in outputs: print(i.size()) """ # ## Evaluation # In[10]: get_ipython().run_line_magic('matplotlib', 'inline') import matplotlib.pyplot as plt # In[11]: ax1 = plt.subplot(1, 1, 1) ax1.plot(range(len(gener_costs)), gener_costs, label='Generator loss') ax1.plot(range(len(discr_costs)), discr_costs, label='Discriminator loss') ax1.set_xlabel('Iterations') ax1.set_ylabel('Loss') ax1.legend() ################### # Set scond x-axis ax2 = ax1.twiny() newlabel = list(range(NUM_EPOCHS+1)) iter_per_epoch = len(train_loader) newpos = [e*iter_per_epoch for e in newlabel] ax2.set_xticklabels(newlabel[::10]) ax2.set_xticks(newpos[::10]) ax2.xaxis.set_ticks_position('bottom') ax2.xaxis.set_label_position('bottom') ax2.spines['bottom'].set_position(('outward', 45)) ax2.set_xlabel('Epochs') ax2.set_xlim(ax1.get_xlim()) ################### plt.show() # In[15]: ########################## ### VISUALIZATION ########################## model.eval() # Make new images z = torch.zeros((10, LATENT_DIM)).uniform_(0.0, 1.0).to(device) generated_features = model.generator_forward(z) imgs = generated_features.view(-1, 28, 28) 
# Plot the 10 generated samples in a single row.
fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(20, 2.5))
for i, ax in enumerate(axes):
    # Move to CPU and detach from the autograd graph before plotting
    ax.imshow(imgs[i].to(torch.device('cpu')).detach(), cmap='binary')


# In[16]:

from torchsummary import summary

# FIX: use the device selected at the top of the script instead of a
# hard-coded 'cuda:0', so this cell also runs on CPU-only machines
# (the notebook header promises "Runs on CPU or GPU").
model = model.to(device)

# torchsummary defaults to device='cuda'; pass the actual device type
# ('cuda' or 'cpu') so the dummy forward pass matches the model's device.
summary(model.generator, input_size=(100,), device=device.type)
summary(model.discriminator, input_size=(1, 28, 28), device=device.type)