#!/usr/bin/env python
# coding: utf-8

# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/3.0-WGAN-GP-fashion-mnist.ipynb)
#
# ## Wasserstein GAN with Gradient Penalty (WGAN-GP) ([WGAN article](https://arxiv.org/abs/1701.07875), [gradient-penalty article](https://arxiv.org/abs/1704.00028))
#
# WGAN-GP is a GAN that replaces the standard GAN loss with a Wasserstein loss plus a gradient
# penalty on the critic, which improves training stability.
#
# ![wgan gp](https://github.com/timsainb/tensorflow2-generative-models/blob/f3360a819b5773692e943dfe181972a76b9d91bb/imgs/gan.png?raw=1)

# ### Install packages if in colab

# In[ ]:


### install necessary packages if in colab
import sys, subprocess


def run_subprocess_command(cmd):
    # run a shell command and print its output line by line
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE)
    for line in process.stdout:
        print(line.decode().strip())


IN_COLAB = 'google.colab' in sys.modules
colab_requirements = ['pip install tf-nightly-gpu-2.0-preview==2.0.0.dev20190513']
if IN_COLAB:
    for i in colab_requirements:
        run_subprocess_command(i)


# ### load packages

# In[ ]:


# make only one GPU visible
get_ipython().run_line_magic('env', 'CUDA_VISIBLE_DEVICES=3')


# In[ ]:


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

get_ipython().run_line_magic('matplotlib', 'inline')
from IPython import display
import pandas as pd


# In[ ]:


print(tf.__version__)


# ### Create a fashion-MNIST dataset

# In[ ]:


TRAIN_BUF = 60000
BATCH_SIZE = 512
TEST_BUF = 10000
DIMS = (28, 28, 1)
N_TRAIN_BATCHES = int(TRAIN_BUF / BATCH_SIZE)
N_TEST_BATCHES = int(TEST_BUF / BATCH_SIZE)


# In[ ]:


# load dataset
(train_images, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()

# reshape to (N, 28, 28, 1) and scale pixel values to [0, 1]
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype(
    "float32"
) / 255.0
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype("float32") / 255.0

# batch datasets
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(TRAIN_BUF)
    .batch(BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .shuffle(TEST_BUF)
    .batch(BATCH_SIZE)
)
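# ### WGAN-GP objective
#
# As a reference for the model defined next, these are the losses that `compute_loss` and
# `gradient_penalty` implement below (a sketch of the objective; $\lambda$ is the
# `gradient_penalty_weight` argument passed when the model is created):
#
# $L_{disc} = \mathbb{E}[D(x)] - \mathbb{E}[D(G(z))] + \lambda\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$, where $\hat{x} = \epsilon x + (1 - \epsilon) G(z)$ and $\epsilon \sim U[0, 1]$
#
# $L_{gen} = \mathbb{E}[D(G(z))]$
#
# The gradient penalty encourages the critic to be approximately 1-Lipschitz, which is what lets
# the Wasserstein formulation train stably without weight clipping.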
# ### Define the network as a tf.keras.Model object

# In[ ]:


class WGAN(tf.keras.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP).

    I used github/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/ as a reference on this.

    Extends: tf.keras.Model
    """

    def __init__(self, **kwargs):
        super(WGAN, self).__init__()
        self.__dict__.update(kwargs)

        self.gen = tf.keras.Sequential(self.gen)
        self.disc = tf.keras.Sequential(self.disc)

    def generate(self, z):
        return self.gen(z)

    def discriminate(self, x):
        return self.disc(x)

    def compute_loss(self, x):
        """ passes through the network and computes the WGAN-GP losses """

        ### pass through network
        # sample noise from a normal distribution
        z_samp = tf.random.normal([x.shape[0], 1, 1, self.n_Z])
        # run noise through generator
        x_gen = self.generate(z_samp)
        # discriminate x and x_gen
        logits_x = self.discriminate(x)
        logits_x_gen = self.discriminate(x_gen)
        # gradient penalty
        d_regularizer = self.gradient_penalty(x, x_gen)

        ### losses
        # critic (discriminator) loss: Wasserstein term plus the weighted gradient penalty
        disc_loss = (
            tf.reduce_mean(logits_x)
            - tf.reduce_mean(logits_x_gen)
            + d_regularizer * self.gradient_penalty_weight
        )
        # generator loss: with the sign convention above the critic scores generated images
        # highly, so the generator minimizes the mean critic score on its samples
        gen_loss = tf.reduce_mean(logits_x_gen)

        return disc_loss, gen_loss

    def compute_gradients(self, x):
        """ passes through the network and computes the gradients of both losses """

        ### pass through network
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            disc_loss, gen_loss = self.compute_loss(x)

        # compute gradients
        gen_gradients = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss, self.disc.trainable_variables)

        return gen_gradients, disc_gradients

    def apply_gradients(self, gen_gradients, disc_gradients):
        self.gen_optimizer.apply_gradients(
            zip(gen_gradients, self.gen.trainable_variables)
        )
        self.disc_optimizer.apply_gradients(
            zip(disc_gradients, self.disc.trainable_variables)
        )

    def gradient_penalty(self, x, x_gen):
        # interpolate between real and generated images
        epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
        x_hat = epsilon * x + (1 - epsilon) * x_gen
        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = self.discriminate(x_hat)
        # penalize the critic when its gradient norm at the interpolates deviates from 1
        gradients = t.gradient(d_hat, x_hat)
        ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
        d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
        return d_regularizer

    @tf.function
    def train(self, train_x):
        gen_gradients, disc_gradients = self.compute_gradients(train_x)
        self.apply_gradients(gen_gradients, disc_gradients)


# ### Define the network architecture

# In[ ]:


N_Z = 64

generator = [
    tf.keras.layers.Dense(units=7 * 7 * 64, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(7, 7, 64)),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
    ),
]

discriminator = [
    tf.keras.layers.InputLayer(input_shape=DIMS),
    tf.keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=(2, 2), activation="relu"
    ),
    tf.keras.layers.Conv2D(
        filters=64, kernel_size=3, strides=(2, 2), activation="relu"
    ),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=1, activation="sigmoid"),
]


# ### Create Model

# In[ ]:


# optimizers
gen_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.RMSprop(0.0005)

# model
model = WGAN(
    gen=generator,
    disc=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    n_Z=N_Z,
    gradient_penalty_weight=10.0,
)
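# Before training, a minimal sanity-check sketch: sample a few latent vectors, push them through
# the (still untrained) generator and critic, and confirm the output shapes match `DIMS`. The
# sample size of 16 here is an arbitrary choice, not something the rest of the notebook uses.

# In[ ]:


z = tf.random.normal([16, 1, 1, N_Z])  # 16 is an arbitrary sample size for this check
fake_images = model.generate(z)  # expected shape: (16, 28, 28, 1)
critic_scores = model.discriminate(fake_images)  # expected shape: (16, 1)
print(fake_images.shape, critic_scores.shape)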
# ### Train the model

# In[ ]:


# plot example images sampled from the generator
def plot_reconstruction(model, nex=8, zm=2):
    samples = model.generate(tf.random.normal(shape=(BATCH_SIZE, N_Z)))
    fig, axs = plt.subplots(ncols=nex, nrows=1, figsize=(zm * nex, zm))
    for axi in range(nex):
        axs[axi].matshow(
            samples.numpy()[axi].squeeze(), cmap=plt.cm.Greys, vmin=0, vmax=1
        )
        axs[axi].axis('off')
    plt.show()


# In[ ]:


# a pandas dataframe to save the loss information to
losses = pd.DataFrame(columns=['disc_loss', 'gen_loss'])


# In[ ]:


n_epochs = 200
for epoch in range(n_epochs):
    # train
    for batch, train_x in tqdm(
        zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES
    ):
        model.train(train_x)
    # test on holdout
    loss = []
    for batch, test_x in tqdm(
        zip(range(N_TEST_BATCHES), test_dataset), total=N_TEST_BATCHES
    ):
        loss.append(model.compute_loss(test_x))
    losses.loc[len(losses)] = np.mean(loss, axis=0)
    # plot results
    display.clear_output()
    print(
        "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
            epoch, losses.disc_loss.values[-1], losses.gen_loss.values[-1]
        )
    )
    plot_reconstruction(model)


# In[ ]:


plt.plot(losses.gen_loss.values)


# In[ ]:


plt.plot(losses.disc_loss.values)
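# A small optional follow-up sketch: overlay the two loss curves on one labeled figure and save
# the trained generator's weights so samples can be drawn later without retraining. The checkpoint
# path is an arbitrary choice, not something the notebook defines.

# In[ ]:


fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses.disc_loss.values, label="disc_loss")
ax.plot(losses.gen_loss.values, label="gen_loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend()
plt.show()

# save the generator weights in TensorFlow checkpoint format (path is arbitrary)
model.gen.save_weights("wgan_gp_fashion_mnist_generator_ckpt")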