Based on the original Generative Adversarial Network (GAN), as introduced by Goodfellow et al. in 2014 [1]
# Toggle between running on Google Colab (with Drive storage) and locally.
COLAB = True
if COLAB:
    # Mount Google Drive so trained models can be persisted across sessions.
    from google.colab import drive
    drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
!pip install tensorview
Requirement already satisfied: tensorview in /usr/local/lib/python3.6/dist-packages (0.4.1) Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tensorview) (3.2.2) Requirement already satisfied: pyecharts-snapshot>=0.1.10tensorflow>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorview) (0.2.0) Requirement already satisfied: pandas>=0.24.1 in /usr/local/lib/python3.6/dist-packages (from tensorview) (1.1.5) Requirement already satisfied: pyecharts>=1.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorview) (1.9.0) Requirement already satisfied: linora>=0.9.3 in /usr/local/lib/python3.6/dist-packages (from tensorview) (0.9.3) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tensorview) (2.4.7) Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tensorview) (1.18.5) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tensorview) (0.10.0) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tensorview) (2.8.1) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tensorview) (1.3.1) Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (7.0.0) Requirement already satisfied: pyppeteer>=0.0.25 in /usr/local/lib/python3.6/dist-packages (from pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (0.2.2) Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.24.1->tensorview) (2018.9) Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from pyecharts>=1.2.0->tensorview) (2.11.2) Requirement already satisfied: prettytable in 
/usr/local/lib/python3.6/dist-packages (from pyecharts>=1.2.0->tensorview) (2.0.0) Requirement already satisfied: simplejson in /usr/local/lib/python3.6/dist-packages (from pyecharts>=1.2.0->tensorview) (3.17.2) Requirement already satisfied: xgboost>=0.81 in /usr/local/lib/python3.6/dist-packages (from linora>=0.9.3->tensorview) (0.90) Requirement already satisfied: tensorflow>=2.0.0rc0 in /usr/local/lib/python3.6/dist-packages (from linora>=0.9.3->tensorview) (2.3.0) Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib->tensorview) (1.15.0) Requirement already satisfied: urllib3<2.0.0,>=1.25.8 in /usr/local/lib/python3.6/dist-packages (from pyppeteer>=0.0.25->pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (1.26.2) Requirement already satisfied: appdirs<2.0.0,>=1.4.3 in /usr/local/lib/python3.6/dist-packages (from pyppeteer>=0.0.25->pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (1.4.4) Requirement already satisfied: tqdm<5.0.0,>=4.42.1 in /usr/local/lib/python3.6/dist-packages (from pyppeteer>=0.0.25->pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (4.54.1) Requirement already satisfied: websockets<9.0,>=8.1 in /usr/local/lib/python3.6/dist-packages (from pyppeteer>=0.0.25->pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (8.1) Requirement already satisfied: pyee<8.0.0,>=7.0.1 in /usr/local/lib/python3.6/dist-packages (from pyppeteer>=0.0.25->pyecharts-snapshot>=0.1.10tensorflow>=2.0.0->tensorview) (7.0.4) Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->pyecharts>=1.2.0->tensorview) (1.1.1) Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from prettytable->pyecharts>=1.2.0->tensorview) (50.3.2) Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from prettytable->pyecharts>=1.2.0->tensorview) (0.2.5) Requirement already satisfied: scipy in 
/usr/local/lib/python3.6/dist-packages (from xgboost>=0.81->linora>=0.9.3->tensorview) (1.4.1) Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.3.0) Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.1.0) Requirement already satisfied: tensorflow-estimator<2.4.0,>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2.3.0) Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.12.1) Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.10.0) Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2.10.0) Requirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.1.2) Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.36.1) Requirement already satisfied: tensorboard<3,>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2.3.0) Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.12.4) Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.34.0) Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.6.3) Requirement already satisfied: 
google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.2.0) Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.3.3) Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2.23.0) Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.17.2) Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.0.1) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.4.2) Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.7.0) Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.3.3) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.0.4) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2020.12.5) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (2.10) Requirement already satisfied: 
cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (4.1.1) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.2.8) Requirement already satisfied: rsa<5,>=3.1.4; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (4.6) Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (1.3.0) Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.1.1) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (0.4.8) Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.1.0) Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tensorboard<3,>=2.3.0->tensorflow>=2.0.0rc0->linora>=0.9.3->tensorview) (3.4.0)
import sys
import tensorflow as tf
import plotly.graph_objects as go
from ipywidgets import widgets
import numpy as np
from tensorflow.keras import models, layers, losses, optimizers, metrics
import tensorflow_datasets as tf_ds
import tensorview as tv
import matplotlib.pyplot as plt
from pathlib import Path
/usr/local/lib/python3.6/dist-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.26.2) or chardet (3.0.4) doesn't match a supported version! RequestsDependencyWarning)
# Where trained models will be saved (Drive on Colab, local folder otherwise).
if COLAB:
    model_path = Path('/content/drive/My Drive/Colab Notebooks/DsStepByStep')
else:
    model_path = Path('model')

# Hyperparameters
batch_size = 100    # images per training batch
latent_dim = 100    # dimension of the generator's input noise vector
image_width, image_height, image_channels = 32, 32, 1  # padded MNIST image geometry
mnist_dim = image_width * image_height * image_channels  # flattened image size (1024)
disc_learning_rate = 0.0002  # Adam learning rate for the discriminator
gen_learning_rate = 0.0002   # Adam learning rate for the generator
relu_alpha = 0.01            # negative-side slope for the LeakyReLU activations
MNIST dataset is optimized to be stored efficiently: images are closely cropped at 28x28 pixels and stored as 1 byte per pixel (uint8 format). However, to get proper performance we need to modify the input data to insert some padding around and convert the pixel format to float on 32 bits.
def normalize_img(image, label):
    """Normalize images from `uint8` [0, 255] to `float32` [-1, 1] and pad to 32 x 32.

    The [-1, 1] target range matches the generator's `tanh` output activation,
    and the 2-pixel border grows MNIST's native 28x28 images to 32x32.

    Args:
        image: batched `uint8` image tensor, shape (batch, 28, 28, 1).
        label: digit labels, passed through unchanged.

    Returns:
        Tuple (float32 image tensor of shape (batch, 32, 32, 1), label).
    """
    # Divide by 127.5 so 255 maps exactly to 1.0; the previous /128 capped
    # the bright end of the range at ~0.992 instead of 1.
    image_float = tf.cast(image, tf.float32) / 127.5 - 1.
    # Pad with -1 (black after normalization) so the border matches the MNIST
    # background; tf.pad's default constant 0 would add a mid-gray frame.
    image_padded = tf.pad(image_float, [[0, 0], [2, 2], [2, 2], [0, 0]],
                          constant_values=-1.)
    return image_padded, label
# Load MNIST pre-batched with (image, label) tuples, then normalize/pad each
# batch once and cache the transformed data in memory for subsequent epochs.
ds_train, ds_test = tf_ds.load('mnist', split=['train', 'test'],
                               batch_size=batch_size, as_supervised=True)
ds_train = ds_train.map(normalize_img).cache()
ds_test = ds_test.map(normalize_img).cache()
ds_train, ds_test
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(HTML(value='Dl Completed...'), FloatProgress(value=0.0, max=4.0), HTML(value='')))
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
(<CacheDataset shapes: ((None, 32, 32, 1), (None,)), types: (tf.float32, tf.int64)>, <CacheDataset shapes: ((None, 32, 32, 1), (None,)), types: (tf.float32, tf.int64)>)
GAN model is built out of a generator and a discriminator:
The generator and discriminator architecture are more or less symmetrical. The generator is increasing the output space dimension step by step using wider and wider layers. The discriminator is similar to other classification networks reducing the input space dimensions down to the binary classification layer.
The "game" is to jointly train the generator and discriminator in order to have the best generator but still being able to detect generated images.
# Generator: maps a latent noise vector to a 32x32x1 image, widening the
# representation step by step (256 -> 512 -> 1024 -> 32*32).
generator = models.Sequential(name='generator')
generator.add(layers.Dense(256, input_dim=latent_dim))
generator.add(layers.LeakyReLU(relu_alpha))
generator.add(layers.Dropout(0.3))
generator.add(layers.Dense(512))
generator.add(layers.LeakyReLU(relu_alpha))
generator.add(layers.Dense(1024))
generator.add(layers.LeakyReLU(relu_alpha))
# tanh keeps pixel values in [-1, 1], matching the normalized training images.
generator.add(layers.Dense(mnist_dim, activation='tanh'))
generator.add(layers.Reshape([32, 32, 1]))
generator.compile()
generator.summary()
Model: "generator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 256) 25856 _________________________________________________________________ leaky_re_lu (LeakyReLU) (None, 256) 0 _________________________________________________________________ dropout (Dropout) (None, 256) 0 _________________________________________________________________ dense_1 (Dense) (None, 512) 131584 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 512) 0 _________________________________________________________________ dense_2 (Dense) (None, 1024) 525312 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 1024) 0 _________________________________________________________________ dense_3 (Dense) (None, 1024) 1049600 _________________________________________________________________ reshape (Reshape) (None, 32, 32, 1) 0 ================================================================= Total params: 1,732,352 Trainable params: 1,732,352 Non-trainable params: 0 _________________________________________________________________
# Discriminator: binary classifier narrowing 32x32x1 images down to one logit.
discriminator = models.Sequential([
    layers.Input(shape=[32, 32, 1]),
    layers.Flatten(),
    layers.Dropout(0.3),
    layers.Dense(1024),
    layers.LeakyReLU(relu_alpha),
    layers.Dense(512),
    # Pass relu_alpha explicitly: the bare LeakyReLU() here defaulted to
    # alpha=0.3, inconsistent with the 0.01 slope used everywhere else.
    layers.LeakyReLU(relu_alpha),
    layers.Dense(256),
    layers.LeakyReLU(relu_alpha),
    # No sigmoid: the loss below is BinaryCrossentropy(from_logits=True).
    layers.Dense(1)
], name='discriminator')
discriminator.compile()
discriminator.summary()
Model: "discriminator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten (Flatten) (None, 1024) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 1024) 0 _________________________________________________________________ dense_4 (Dense) (None, 1024) 1049600 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 1024) 0 _________________________________________________________________ dense_5 (Dense) (None, 512) 524800 _________________________________________________________________ leaky_re_lu_4 (LeakyReLU) (None, 512) 0 _________________________________________________________________ dense_6 (Dense) (None, 256) 131328 _________________________________________________________________ leaky_re_lu_5 (LeakyReLU) (None, 256) 0 _________________________________________________________________ dense_7 (Dense) (None, 1) 257 ================================================================= Total params: 1,705,985 Trainable params: 1,705,985 Non-trainable params: 0 _________________________________________________________________
Training alternates between the discriminator and the generator.
The discriminator is trained on batches made of half genuine images and half generated images.
The generator is trained with its output fed into the discriminator (whose weights are frozen in this phase).
GANs' reputation as being difficult to train is well deserved, and originates in the joint optimization, which is a minimax problem (minimize the discrimination error, maximize the fidelity of the fakes). As seen below, the noise on the losses and accuracies is high. The main facilitators helping this training are:
epochs = 60
# Integer division: this is a batch count (used as an iteration total below),
# so 60000/batch_size must not produce a float like 600.0.
batch_per_epoch = 60000 // batch_size  # MNIST training set holds 60000 images
# from_logits=True because the discriminator's final Dense(1) has no sigmoid.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output):
    """BCE loss pushing the discriminator's verdict on fakes toward 'real' (1)."""
    all_real = tf.ones_like(disc_generated_output)
    return loss_object(all_real, disc_generated_output)
def discriminator_loss(disc_real_output, disc_generated_output):
    """Sum of BCE losses: real images scored toward 1, generated ones toward 0."""
    loss_on_real = loss_object(tf.ones_like(disc_real_output),
                               disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output),
                               disc_generated_output)
    return loss_on_real + loss_on_fake
@tf.function
def train_step(generator, discriminator,
               generator_optimizer, discriminator_optimizer,
               generator_latent, batch,
               epoch):
    """Run one joint optimization step on both networks.

    A fresh latent batch is generated, fakes are produced, and both the real
    batch and the fakes are scored by the discriminator. Gradients for the two
    losses are recorded on separate tapes and applied with separate optimizers.

    Args:
        generator, discriminator: the two Keras models being trained.
        generator_optimizer, discriminator_optimizer: their optimizers.
        generator_latent: zero-argument callable returning a latent batch.
        batch: batch of real (normalized, padded) images.
        epoch: current epoch index.  # NOTE(review): currently unused in the body

    Returns:
        Tuple (generator loss, discriminator loss) for this step.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_latent = generator_latent()
        gen_output = generator(gen_latent, training=True)
        disc_real_output = discriminator(batch, training=True)
        disc_generated_output = discriminator(gen_output, training=True)
        gen_loss = generator_loss(disc_generated_output)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    return gen_loss, disc_loss
# One Adam optimizer per network so the two gradient streams stay independent.
# NOTE(review): beta_1=0.05 is unusually low (GAN literature commonly uses 0.5);
# looks deliberate here but worth confirming.
generator_optimizer = tf.keras.optimizers.Adam(gen_learning_rate, beta_1=0.05)
discriminator_optimizer = tf.keras.optimizers.Adam(disc_learning_rate, beta_1=0.05)
# Live loss plot refreshed every 200 iterations over the full training run.
tv_plot = tv.train.PlotMetrics(wait_num=200, columns=2, iter_num=epochs * batch_per_epoch)
def generator_latent():
    """Sample one batch of latent vectors from a standard normal distribution."""
    latent_shape = (batch_size, latent_dim)
    return tf.random.normal(latent_shape, mean=0, stddev=1)
# Main training loop: one train_step per batch, losses streamed to the plot.
for epoch in range(epochs):
    for train_batch in iter(ds_train):
        # train_batch is (images, labels); only the images are used.
        g_loss, d_loss = train_step(generator, discriminator,
                                    generator_optimizer, discriminator_optimizer,
                                    generator_latent, train_batch[0],
                                    epoch)
        # Plot
        tv_plot.update({'discriminator_loss': d_loss,  # 'discriminator_acc': d_acc,
                        'generator_loss': g_loss,  # 'generator_acc': g_acc
                        })
tv_plot.draw()
# Generate a batch of images and display the first 64 in an 8x8 grid.
gen_latent = generator_latent()
# NOTE(review): training=True keeps Dropout active while sampling; for pure
# inference training=False would be expected — confirm this is intentional.
gen_imgs = generator(gen_latent, training=True).numpy()
fig, axes = plt.subplots(8, 8, sharex=True, sharey=True, figsize=(10, 10))
for img, ax in zip(gen_imgs, axes.ravel()):
    # Reshape (32, 32, 1) to (32, 32) for grayscale display.
    ax.imshow(img.reshape(image_width, image_height), interpolation='nearest', cmap='gray')
    ax.axis('off')
fig.tight_layout()
# Persist both trained networks (HDF5 format) for later reuse.
discriminator.save(model_path / 'mnist_gan_discriminator.h5')
generator.save(model_path / 'mnist_gan_generator.h5')
The generator is not fooling the discriminator that much, since the discriminator accuracy is well above 50% and its loss is stable at a low level. However, the generated digits look quite convincing to a human eye.