Generating Images of Clothes Using Deep Convolutional Generative Adversarial Network (DCGAN)

Experiment overview

In this experiment we will generate images of clothing using a Deep Convolutional Generative Adversarial Network (DCGAN). The code is written using the Keras Sequential API with a tf.GradientTape training loop. For training we will use the Fashion MNIST dataset.

A generative adversarial network (GAN) is a class of machine learning frameworks in which two neural networks contest with each other in a game. The two models are trained simultaneously by an adversarial process: a generator ("the artist") learns to create images that look real, while a discriminator ("the art critic") learns to tell real images apart from fakes.
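The adversarial objective can be sketched numerically. Below is a minimal NumPy illustration (not the tf.GradientTape training loop used later in this experiment): the discriminator minimizes binary cross-entropy against real/fake labels, while the generator is rewarded when the discriminator labels its fakes as real. The probability values here are made up purely for illustration.

```python
import numpy as np

def binary_cross_entropy(labels, predictions, eps=1e-7):
    # Standard binary cross-entropy, averaged over the batch.
    p = np.clip(predictions, eps, 1 - eps)
    return float(np.mean(-(labels * np.log(p) + (1 - labels) * np.log(1 - p))))

# Hypothetical discriminator outputs (probability of "real") for a batch
# of real images and a batch of generator fakes.
d_on_real = np.array([0.9, 0.8, 0.95])
d_on_fake = np.array([0.1, 0.2, 0.05])

# Discriminator loss: real images should be classified as 1, fakes as 0.
d_loss = (binary_cross_entropy(np.ones_like(d_on_real), d_on_real)
          + binary_cross_entropy(np.zeros_like(d_on_fake), d_on_fake))

# Generator loss: the generator "wins" when the discriminator calls its fakes real.
g_loss = binary_cross_entropy(np.ones_like(d_on_fake), d_on_fake)

print('d_loss:', d_loss)
print('g_loss:', g_loss)
```

Here the discriminator is doing well (low `d_loss`), so the generator's loss is high; training pushes both losses toward an equilibrium.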

clothes_generation_dcgan.jpg

Importing dependencies

In [49]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math
import datetime
import platform
import imageio
import PIL
import time
import os
import glob
import zipfile

from IPython import display

print('Python version:', platform.python_version())
print('Tensorflow version:', tf.__version__)
print('Keras version:', tf.keras.__version__)
Python version: 3.7.6
Tensorflow version: 2.1.0
Keras version: 2.2.4-tf
In [50]:
# Check that eager execution is enabled.
tf.executing_eagerly()
Out[50]:
True

Loading the data

In [51]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
In [52]:
print('x_train.shape: ', x_train.shape)
print('y_train.shape: ', y_train.shape)
print()
print('x_test.shape: ', x_test.shape)
print('y_test.shape: ', y_test.shape)
x_train.shape:  (60000, 28, 28)
y_train.shape:  (60000,)

x_test.shape:  (10000, 28, 28)
y_test.shape:  (10000,)
In [53]:
# Since we don't need test examples, we can concatenate both sets.
x_train = np.concatenate((x_train, x_test), axis=0)

print('x_train.shape: ', x_train.shape)
x_train.shape:  (70000, 28, 28)
In [54]:
TOTAL_EXAMPLES_NUM = x_train.shape[0]

print('TOTAL_EXAMPLES_NUM: ', TOTAL_EXAMPLES_NUM)
TOTAL_EXAMPLES_NUM:  70000
In [57]:
print('y_train[0] =', y_train[0])
y_train[0] = 9

Here is the map of classes for the dataset, according to the documentation:

Label Class
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
In [58]:
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot'
]

Analyzing the dataset

In [59]:
print(x_train[0])
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0  13  73   0
    0   1   4   0   0   0   0   1   1   0]
 ...
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]
In [60]:
plt.figure(figsize=(2, 2))
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()
In [11]:
numbers_to_display = 25
num_cells = math.ceil(math.sqrt(numbers_to_display))
plt.figure(figsize=(8, 8))
for i in range(numbers_to_display):
    plt.subplot(num_cells, num_cells, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[y_train[i]])
plt.show()

Reshape and normalize data

In [0]:
x_train_reshaped = x_train.reshape(
    x_train.shape[0],
    x_train.shape[1],
    x_train.shape[2],
    1
).astype('float32')
In [13]:
print('x_train_reshaped.shape: ', x_train_reshaped.shape)
x_train_reshaped.shape:  (70000, 28, 28, 1)
In [0]:
# Normalize image pixel values to [-1, 1] range
x_train_normalized = (x_train_reshaped - 127.5) / 127.5
In [65]:
print('Normalized data values:\n')
print(x_train_normalized[0,:,:,0])
Normalized data values:

[[-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 ...
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.9764706  -1.         -0.7176471   0.06666667 -0.00392157 -0.5137255
  -0.5764706  -1.         -1.         -1.         -0.99215686 -0.9764706
  -0.96862745 -1.         -1.         -0.9764706 ]
 ...
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]]
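The [-1, 1] range is chosen to match the tanh activation on the generator's output layer. Both the normalization and its inverse (which will be needed to display generated images) are simple affine maps; a quick sanity check:

```python
import numpy as np

pixels = np.array([0, 127.5, 255], dtype='float32')

normalized = (pixels - 127.5) / 127.5  # [0, 255] -> [-1, 1]
restored = normalized * 127.5 + 127.5  # [-1, 1] -> [0, 255]

print('normalized:', normalized)
print('restored:  ', restored)
```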

Creating a batched dataset

In [66]:
SHUFFLE_BUFFER_SIZE = TOTAL_EXAMPLES_NUM
BATCH_SIZE = 1024

TRAINING_STEPS_PER_EPOCH = math.ceil(TOTAL_EXAMPLES_NUM / BATCH_SIZE)

print('BATCH_SIZE: ', BATCH_SIZE)
print('TRAINING_STEPS_PER_EPOCH: ', TRAINING_STEPS_PER_EPOCH)
BATCH_SIZE:  1024
TRAINING_STEPS_PER_EPOCH:  69
In [67]:
train_dataset = tf.data.Dataset.from_tensor_slices(x_train_normalized) \
    .shuffle(SHUFFLE_BUFFER_SIZE) \
    .batch(BATCH_SIZE)

print(train_dataset)
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
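Note that 70000 is not divisible by 1024, so the last batch of each epoch is smaller than the rest; this is why TRAINING_STEPS_PER_EPOCH uses math.ceil. The batch arithmetic can be checked directly:

```python
import math

TOTAL_EXAMPLES_NUM = 70000
BATCH_SIZE = 1024

full_batches, remainder = divmod(TOTAL_EXAMPLES_NUM, BATCH_SIZE)

print('full batches:   ', full_batches)  # 68
print('last batch size:', remainder)     # 368
print('steps per epoch:', math.ceil(TOTAL_EXAMPLES_NUM / BATCH_SIZE))  # 69
```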

Create models

Create generator

In [68]:
def make_generator_model():
    model = tf.keras.Sequential()
    
    # Step 1.
    model.add(tf.keras.layers.Dense(
        units=7*7*256,
        use_bias=False,
        input_shape=(100,)
    ))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    # Step 2.
    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # None is the batch dimension.
    
    # Step 3.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=128,
        kernel_size=(5, 5),
        strides=(1, 1),
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    # Step 4.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=64,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    
    # Step 5.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=1,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same',
        use_bias=False,
        activation='tanh'
    ))
    assert model.output_shape == (None, 28, 28, 1)
    
    return model
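The output-shape assertions above follow from the Conv2DTranspose arithmetic: with padding='same', a transposed convolution upsamples its input by exactly the stride, independent of kernel size. A tiny helper (illustrative only, not part of Keras) traces the 7 → 7 → 14 → 28 progression:

```python
def conv2d_transpose_same_size(input_size, stride):
    # With padding='same', Conv2DTranspose upsamples by exactly the stride,
    # regardless of kernel size.
    return input_size * stride

size = 7
for stride in (1, 2, 2):  # strides of the three Conv2DTranspose layers above
    size = conv2d_transpose_same_size(size, stride)
    print(size)  # 7, then 14, then 28
```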
In [69]:
generator_model = make_generator_model()
In [70]:
generator_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (None, 12544)             1254400   
_________________________________________________________________
batch_normalization_3 (Batch (None, 12544)             50176     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 12544)             0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 7, 7, 128)         819200    
_________________________________________________________________
batch_normalization_4 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 14, 14, 64)        204800    
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 28, 28, 1)         1600      
=================================================================
Total params: 2,330,944
Trainable params: 2,305,472
Non-trainable params: 25,472
_________________________________________________________________
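The parameter counts in the summary can be verified by hand: a bias-free Dense layer has input_dim × units weights, a bias-free Conv2DTranspose has kernel_h × kernel_w × filters × input_channels weights, and BatchNormalization carries 4 parameters per channel (gamma and beta, which are trainable, plus the moving mean and variance, which are not):

```python
dense = 100 * 12544         # 1,254,400
bn1 = 4 * 12544             # 50,176
convt1 = 5 * 5 * 128 * 256  # 819,200
bn2 = 4 * 128               # 512
convt2 = 5 * 5 * 64 * 128   # 204,800
bn3 = 4 * 64                # 256
convt3 = 5 * 5 * 1 * 64     # 1,600

total = dense + bn1 + convt1 + bn2 + convt2 + bn3 + convt3
print(total)  # 2330944 -- matches "Total params: 2,330,944" in the summary
```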
In [71]:
tf.keras.utils.plot_model(
    generator_model,
    show_shapes=True,
    show_layer_names=True,
    to_file='generator_model.png'
)
Out[71]:
In [72]:
noise = tf.random.normal(shape=[1, 100])

print(noise.numpy())
[[-0.40048575 -0.05899175  1.4562819  -1.0014365  -0.49212027 -0.02965875
  -0.359166   -0.13207589 -0.05057759 -0.7165569   1.348586   -0.5673278
   1.2634926   0.69921404  0.1937053   0.6227805  -0.8695679   1.5103983
  -0.87547743 -0.9266758  -0.94829696 -0.05095936 -0.33741903  1.7615291
  -0.5133022  -1.9463009  -0.47871232 -0.3488199   0.30957273  0.7816733
  -0.05789489 -0.65704983  1.328389   -0.17778188 -0.09234082  1.3932896
  -0.4493648   1.0459919  -0.27539548 -0.5148557   0.6957658   0.2747286
   1.1346655  -0.58931726  1.2654362  -1.3202895   0.05925234  0.57879525
   0.10334445  0.12310242  0.9459113   0.46510884 -1.7625321   1.1537969
   0.95482457  0.07740594  1.101719   -0.94574606  0.90892     1.0484036
  -1.3593727   0.44116554 -1.3168752  -0.6017726   0.22471838  0.524481
  -0.14121588  0.20263085 -0.29945704  0.7611715   0.42068037  0.25163245
   1.2361112  -0.49910775  0.5315448  -0.29282215  0.44396016  0.18037954
  -0.66255087 -0.01424331  0.25285572  0.54898584  0.01249223 -1.8984798
   0.34870186 -0.38278985 -1.2752656  -0.22798996  1.2903773   0.5359575
  -0.8025157   1.1422137   1.4703174  -0.08264713  0.27506617  1.0309777
  -0.48443326  0.76453096  0.57719016  0.5316679 ]]
In [73]:
generated_image = generator_model(noise, training=False)

print('generated_image.shape: ', generated_image.shape)
generated_image.shape:  (1, 28, 28, 1)
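Even untrained, the generator maps the noise vector to a 28×28 single-channel image whose values lie in the tanh range [-1, 1]. To view it with Matplotlib the values can be rescaled to [0, 1]; the sketch below uses random data as a stand-in for generated_image so it runs outside the notebook:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # non-interactive backend for scripted use
import matplotlib.pyplot as plt

# Stand-in for generator output: values in [-1, 1], shape (1, 28, 28, 1).
generated_image = np.random.uniform(-1, 1, size=(1, 28, 28, 1)).astype('float32')

# Map the tanh range [-1, 1] back to [0, 1] for display.
display_image = generated_image[0, :, :, 0] * 0.5 + 0.5

plt.imshow(display_image, cmap='gray')
plt.axis('off')
plt.savefig('untrained_generator_sample.png')
```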
In [74]:
print(generated_image[0, :, :, 0].numpy())
[[ 3.62146273e-03 -3.37373791e-03 -2.23259837e-03  4.66620969e-03
   ...
  -1.84537598e-03 -2.84438924e-04 -1.76289608e-03  6.83842576e-04]
 ...
 [-2.46818783e-03  5.46638342e-03 -2.33211252e-03  3.60420416e-03
   ...
  -1.22753307e-02 -1.53727038e-02 -3.91660351e-03 -3.69235151e-03]
 [ 1.52455550e-03  4.71905223e-04  2.51185521e-03  3.77575343e-04
   3.94193502e-03 -2.46343319e-03  1.21627003e-03  1.43145500e-02
   2.26958841e-02 -2.91464501e-03 -2.12998893e-02 -1.22428909e-02
  -4.66084020e-04  9.76727158e-03 -1.22103025e-03  2.02955268e-02
   1.04613770e-02  8.11114442e-03 -4.82826121e-03  2.12394483e-02
   2.55176052e-02 -1.36667909e-02 -1.34665845e-02  6.25952985e-03
   7.91868195e-03 -2.37570284e-03 -1.96220633e-02  2.20343564e-03]]
In [75]:
plt.figure(figsize=(2, 2))
plt.imshow(generated_image[0, :, :, 0], cmap=plt.cm.binary)
Out[75]:
<matplotlib.image.AxesImage at 0x1315d7450>

Create discriminator

The model will be trained to output positive values for real images, and negative values for fake images.
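Since the final Dense layer has no activation, the model outputs a raw logit rather than a probability. How such a logit maps to a "real" probability can be sketched with a plain sigmoid (a minimal pure-Python illustration, not part of the model itself):

```python
import math

def sigmoid(x):
    # Maps a raw logit to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-x))

# Positive logits correspond to "real" (probability > 0.5),
# negative logits to "fake" (probability < 0.5).
print(sigmoid(3.0))   # ≈ 0.95
print(sigmoid(-3.0))  # ≈ 0.05
```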

In [76]:
def make_discriminator_model():
    model = tf.keras.Sequential()
    
    # Step 1.
    model.add(tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same',
        input_shape=[28, 28, 1]
    ))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
    
    # Step 2.
    model.add(tf.keras.layers.Conv2D(
        filters=128,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same'
    ))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
    
    # Step 3.
    model.add(tf.keras.layers.Flatten())
    
    # Real vs Fake
    model.add(tf.keras.layers.Dense(1))
    
    return model
In [78]:
discriminator_model = make_discriminator_model()
In [79]:
discriminator_model.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 6273      
=================================================================
Total params: 212,865
Trainable params: 212,865
Non-trainable params: 0
_________________________________________________________________
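As a sanity check, the parameter counts in the summary above can be reproduced by hand (a quick back-of-the-envelope sketch):

```python
# Conv2D params = kernel_h * kernel_w * in_channels * filters + filters (biases).
conv1 = 5 * 5 * 1 * 64 + 64       # 1,664
conv2 = 5 * 5 * 64 * 128 + 128    # 204,928

# Two stride-2 'same' convolutions shrink 28x28 to 7x7,
# so the flattened feature map has 7 * 7 * 128 = 6272 values.
dense = 7 * 7 * 128 * 1 + 1       # 6,273

print(conv1 + conv2 + dense)      # 212865
```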
In [80]:
tf.keras.utils.plot_model(
    discriminator_model,
    show_shapes=True,
    show_layer_names=True,
    to_file='discriminator_model.png'
)
Out[80]:
In [81]:
decision = discriminator_model(generated_image)

print(decision)
tf.Tensor([[0.00100495]], shape=(1, 1), dtype=float32)

Define the loss and optimizers

Discriminator loss

This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s.

In [82]:
def discriminator_loss(real_output, fake_output):
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    
    total_loss = real_loss + fake_loss
    
    return total_loss
In [83]:
# Test discriminator loss function.
output_combinations = [
     # REAL  #FAKE
    ([-1.],  [1.]),
    ([1.],   [-1.]),  
    ([1.],   [0.]),
    ([10.],  [-1.]),                       
]
for (real_output, fake_output) in output_combinations:
    loss = discriminator_loss(real_output, fake_output).numpy()
    print('Discriminator loss for:', real_output, fake_output)
    print('  REAL output:', real_output)
    print('  FAKE output:', fake_output)
    print('  loss: ', loss)
    print()
Discriminator loss for: [-1.0] [1.0]
  REAL output: [-1.0]
  FAKE output: [1.0]
  loss:  2.6265235

Discriminator loss for: [1.0] [-1.0]
  REAL output: [1.0]
  FAKE output: [-1.0]
  loss:  0.6265234

Discriminator loss for: [1.0] [0.0]
  REAL output: [1.0]
  FAKE output: [0.0]
  loss:  1.0064089

Discriminator loss for: [10.0] [-1.0]
  REAL output: [10.0]
  FAKE output: [-1.0]
  loss:  0.31330708
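The printed value for logits `[1.]` (real) and `[-1.]` (fake) can be verified by hand: with from-logits binary cross-entropy, the total loss is `-log(sigmoid(real)) - log(1 - sigmoid(fake))` (a minimal pure-Python check):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

real_logit, fake_logit = 1.0, -1.0
real_loss = -math.log(sigmoid(real_logit))      # target 1 on the real image
fake_loss = -math.log(1 - sigmoid(fake_logit))  # target 0 on the fake image

print(round(real_loss + fake_loss, 7))  # ≈ 0.6265234, matching the TF value above
```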

Generator loss

The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminator's decisions on the generated images to an array of 1s.

In [84]:
def generator_loss(fake_output):
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    loss = cross_entropy(tf.ones_like(fake_output), fake_output)
    return loss
In [85]:
# Test generator loss function.
print('Generator loss for logit 5: ', generator_loss([5.]).numpy())
print('Generator loss for logit 0: ', generator_loss([0.]).numpy())
Generator loss for logit 5:  0.0067153485
Generator loss for logit 0:  0.6931472
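These values can likewise be checked by hand: with from-logits cross-entropy against a target of 1, the loss is `-log(sigmoid(logit))`, so a logit of 0 gives exactly `-log(0.5) = ln 2` (a minimal pure-Python check):

```python
import math

def generator_loss_by_hand(fake_logit):
    # Cross-entropy against a target of 1: -log(sigmoid(logit)).
    return -math.log(1.0 / (1.0 + math.exp(-fake_logit)))

print(generator_loss_by_hand(5.0))  # ≈ 0.0067153
print(generator_loss_by_hand(0.0))  # ≈ 0.6931472 (= ln 2)
```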

Optimizers

In [86]:
generator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.0001
)

discriminator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.0001
)

Checkpoints

In [0]:
checkpoint_dir = './tmp/ckpt'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator_model=generator_model,
    discriminator_model=discriminator_model
)

Training

In [0]:
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed over time (so it's easier
# to visualize progress in the animated GIF).
input_noise_seed = tf.random.normal([num_examples_to_generate, noise_dim])

The training loop begins with the generator receiving a random seed as input. That seed is used to produce an image. The discriminator is then used to classify real images (drawn from the training set) and fake images (produced by the generator). The loss is calculated for each of these models, and the gradients are used to update the generator and the discriminator.

In [87]:
# The `@tf.function` annotation would compile this function into a graph for speed,
# but it is left disabled here because the `.numpy()` calls below require eager execution.
# @tf.function
def train_step(real_images):
    training_history = {
        'discriminator': {
            'loss': None
        },
        'generator': {
            'loss': None
        }
    }

    # Generate input noise.
    noise_images = tf.random.normal([BATCH_SIZE, noise_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake images.
        generated_images = generator_model(
            noise_images,
            training=True
        )
        
        # Detect fake and real images.
        real_output = discriminator_model(
            real_images,
            training=True
        )
        fake_output = discriminator_model(
            generated_images,
            training=True
        )
        
        # Calculate losses.
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(
            real_output,
            fake_output
        )

        training_history['discriminator']['loss'] = disc_loss.numpy()
        training_history['generator']['loss'] = gen_loss.numpy()
    
    # Calculate gradients.
    gradients_of_generator = gen_tape.gradient(
        gen_loss,
        generator_model.trainable_variables
    )
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss,
        discriminator_model.trainable_variables
    )
    
    # Do gradient step.
    generator_optimizer.apply_gradients(zip(
        gradients_of_generator,
        generator_model.trainable_variables
    ))
    discriminator_optimizer.apply_gradients(zip(
        gradients_of_discriminator,
        discriminator_model.trainable_variables
    ))

    return training_history
In [88]:
def train(dataset, epochs, start_epoch=0):
    print('Start training...')
    
    training_history = {
        'discriminator': {
            'loss': []
        },
        'generator': {
            'loss': []
        }
    }

    for epoch in range(start_epoch, epochs):
        print('Start epoch #{} ({} steps)...'.format(epoch + 1, TRAINING_STEPS_PER_EPOCH))
        
        start = time.time()
        
        step = 0
        for image_batch in dataset:
            step += 1
            
            # display.clear_output(wait=True)
            # show_progress(step, TRAINING_STEPS_PER_EPOCH)
            # generate_and_save_images(
            #     generator_model,
            #     epoch + 1,
            #     input_noise_seed,
            #     save=False
            # )
            
            training_step_history = train_step(image_batch)

            discriminator_step_loss = training_step_history['discriminator']['loss']
            generator_step_loss = training_step_history['generator']['loss']

        # Remember the losses from the last step of the epoch.
        training_history['discriminator']['loss'].append(discriminator_step_loss)
        training_history['generator']['loss'].append(generator_step_loss)
            
        # Produce images for the GIF as we go.
        display.clear_output(wait=True)
        generate_and_save_images(
            generator_model,
            epoch + 1,
            input_noise_seed
        )
        
        # Save the model every 10 epochs.
        if (epoch + 1) % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            
        print('Time for epoch #{} is {:.2f}s'.format(epoch + 1, time.time() - start))
        print('Discriminator loss: {:.4f}'.format(discriminator_step_loss))
        print('Generator loss: {:.4f}'.format(generator_step_loss))

    return training_history
In [89]:
def show_progress(current_step, total_steps):
    length_divider = 2
    progress = math.floor(current_step * 100 / total_steps)
    done_steps = progress
    left_steps = 100 - done_steps
    done_dots = ''.join(['◼︎' for step in range(math.floor(done_steps / length_divider))])
    left_dots = ''.join(['・' for step in range(math.floor(left_steps / length_divider))])
    print(f'{current_step}/{total_steps}: {done_dots}{left_dots}')
In [90]:
# Test progress function.
show_progress(15, 68)
15/68: ◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・
In [91]:
IMAGES_PREVIEW_PATH = 'tmp/imgs/'

if not os.path.exists(IMAGES_PREVIEW_PATH):
    os.makedirs(IMAGES_PREVIEW_PATH)
In [92]:
def generate_and_save_images(model, epoch, test_input, save=True):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)
    
    fig_dimension = int(math.sqrt(num_examples_to_generate))
    plt.figure(figsize=(fig_dimension, fig_dimension))
    
    for i in range(predictions.shape[0]):
        plt.subplot(fig_dimension, fig_dimension, i+1)
        plt.imshow(
            predictions[i, :, :, 0] * 127.5 + 127.5,
            cmap=plt.cm.binary
        )
        plt.axis('off')
        
    if save:
        plt.savefig('{}image_at_epoch_{:04d}.png'.format(IMAGES_PREVIEW_PATH, epoch))
        
    plt.show()
In [93]:
if 'training_history' not in locals():
    training_history = {
        'discriminator': {
            'loss': []
        },
        'generator': {
            'loss': []
        }
    }
In [0]:
training_session_num = 1
start_epoch = training_session_num * EPOCHS
epochs_num = start_epoch + EPOCHS
In [168]:
training_history_current = train(
    train_dataset,
    epochs=epochs_num,
    start_epoch=start_epoch
)
Time for epoch #2700 is 10.57s
Discriminator loss: 1.2439
Generator loss: 0.9003

Analyzing training history

In [0]:
training_history['generator']['loss'] += training_history_current['generator']['loss']
training_history['discriminator']['loss'] += training_history_current['discriminator']['loss']
In [0]:
def render_training_history(training_history):
    generator_loss = training_history['generator']['loss']
    discriminator_loss = training_history['discriminator']['loss']

    plt.figure(figsize=(14, 4))

    plt.subplot(1, 2, 1)
    plt.title('Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.plot(generator_loss, label='Generator loss')
    plt.legend()
    plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.subplot(1, 2, 2)
    plt.title('Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.plot(discriminator_loss, label='Discriminator loss')
    plt.legend()
    plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.show()
In [170]:
render_training_history(training_history)
In [0]:
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Save models

In [0]:
generator_model.save('generator_model.h5', save_format='h5')
discriminator_model.save('discriminator_model.h5', save_format='h5')
In [107]:
# Restore models from files if needed.
# generator_model.load_weights('./generator_model.h5')
# discriminator_model.load_weights('./discriminator_model.h5')

Zip all preview images

In [96]:
def zip_image_previews():
    images_previews_path = IMAGES_PREVIEW_PATH
    images_previews_zip_name = 'images_previews.zip'

    zipped_files_num = 0
    with zipfile.ZipFile(images_previews_zip_name, mode='w') as zip_obj:
        for folder_name, subfolders, filenames in os.walk(images_previews_path):
            for filename in filenames:
                zipped_files_num += 1
                file_path = os.path.join(folder_name, filename)
                zip_obj.write(file_path, os.path.basename(file_path))
    print('Zipped {} files to '.format(zipped_files_num), images_previews_zip_name)
In [178]:
zip_image_previews()
Zipped 2600 files to  images_previews.zip

Trying models

In [118]:
test_examples_num = 10

noise_images = tf.random.normal([test_examples_num, 100])
generated_images = generator_model(noise_images, training=False)

for example_num in range(test_examples_num):
    plt.figure(figsize=(3, 3))

    plt.subplot(1, 2, 1)
    plt.imshow(np.reshape(noise_images[example_num], (10, 10)), cmap=plt.cm.binary)

    plt.subplot(1, 2, 2)
    plt.imshow(generated_images[example_num, :, :, 0], cmap=plt.cm.binary)

Create a GIF

In [97]:
# Display a single image using the epoch number
def display_image(epoch_no):
    return PIL.Image.open('{}image_at_epoch_{:04d}.png'.format(IMAGES_PREVIEW_PATH, epoch_no))
In [99]:
display_image(EPOCHS)
Out[99]:
In [123]:
anim_file = 'clothes_generation_dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob(IMAGES_PREVIEW_PATH + 'image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2*(i**0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    image = imageio.imread(filename)
    writer.append_data(image)

display.Image(filename=anim_file)
Out[123]:
<IPython.core.display.Image object>
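The `2*(i**0.5)` schedule above keeps more of the early frames (where the generator improves quickly) and skips more of the later ones. Which frame indices survive the filter can be sketched in isolation (a standalone check, independent of the actual image files):

```python
import math

# Replay the frame-selection logic from the GIF writer above.
kept = []
last = -1
for i in range(100):
    frame = 2 * math.sqrt(i)
    if round(frame) > round(last):
        last = frame
        kept.append(i)

# Early frames are dense, later frames increasingly sparse.
print(kept[:7])  # [0, 1, 2, 4, 6, 8, 11]
```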

Clothes Generation with DCGAN demo

Converting the model to web-format

To use this model on the web we need to convert it into a format that TensorFlow.js understands. To do so we may use the tfjs-converter as follows:

tensorflowjs_converter --input_format keras \
  ./experiments/clothes_generation_dcgan/generator_model.h5 \
  ./demos/public/models/clothes_generation_dcgan

You can find this experiment in the Demo app and play around with it right in your browser to see how the model performs in real life.