- 🤖 See full list of Machine Learning Experiments on GitHub
- ▶️ Interactive Demo: try this model and other machine learning experiments in action
In this experiment we will generate images of clothing using a Deep Convolutional Generative Adversarial Network (DCGAN). The code is written using the Keras Sequential API with a tf.GradientTape training loop. For training we will be using the Fashion MNIST dataset.
A generative adversarial network (GAN) is a class of machine learning frameworks. Two neural networks contest with each other in a game. Two models are trained simultaneously by an adversarial process. A generator ("the artist") learns to create images that look real, while a discriminator ("the art critic") learns to tell real images apart from fakes.
Inspired by: Deep Convolutional Generative Adversarial Network tutorial.
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math
import datetime
import platform
import imageio
import PIL
import time
import os
import glob
import zipfile
from IPython import display
# Print the Python, TensorFlow and Keras versions used for this run.
print('Python version:', platform.python_version())
print('Tensorflow version:', tf.__version__)
print('Keras version:', tf.keras.__version__)
Python version: 3.7.6 Tensorflow version: 2.1.0 Keras version: 2.2.4-tf
# Check whether eager execution is enabled (the call returns a bool).
tf.executing_eagerly()
True
# Load Fashion MNIST as (images, labels) pairs for the train and test splits,
# then print the shape of every array.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print('x_train.shape: ', x_train.shape)
print('y_train.shape: ', y_train.shape)
print()
print('x_test.shape: ', x_test.shape)
print('y_test.shape: ', y_test.shape)
x_train.shape: (60000, 28, 28) y_train.shape: (60000,) x_test.shape: (10000, 28, 28) y_test.shape: (10000,)
# Since we don't need a separate test set, append the test images to the
# training images along the first (example) axis.
# NOTE(review): y_train is NOT concatenated with y_test, so labels are only
# available for the images that came from the original training split.
x_train = np.concatenate((x_train, x_test), axis=0)
print('x_train.shape: ', x_train.shape)
x_train.shape: (70000, 28, 28)
# Total number of images (size of the first axis of x_train).
TOTAL_EXAMPLES_NUM = x_train.shape[0]
print('TOTAL_EXAMPLES_NUM: ', TOTAL_EXAMPLES_NUM)
TOTAL_EXAMPLES_NUM: 70000
# Print the label of the first training image (an integer class index).
print('y_train[0] =', y_train[0])
y_train[0] = 9
Here is the map of classes for the dataset, according to the documentation:
Label | Class |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
# Human-readable class names; the list index is the integer label (0-9).
class_names = [
'T-shirt/top',
'Trouser',
'Pullover',
'Dress',
'Coat',
'Sandal',
'Shirt',
'Sneaker',
'Bag',
'Ankle boot'
]
# Print the raw pixel values of the first image.
print(x_train[0])
[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 13 73 0 0 1 4 0 0 0 0 1 1 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 3 0 36 136 127 62 54 0 0 0 1 3 4 0 0 3] [ 0 0 0 0 0 0 0 0 0 0 0 0 6 0 102 204 176 134 144 123 23 0 0 0 0 12 10 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 155 236 207 178 107 156 161 109 64 23 77 130 72 15] [ 0 0 0 0 0 0 0 0 0 0 0 1 0 69 207 223 218 216 216 163 127 121 122 146 141 88 172 66] [ 0 0 0 0 0 0 0 0 0 1 1 1 0 200 232 232 233 229 223 223 215 213 164 127 123 196 229 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 183 225 216 223 228 235 227 224 222 224 221 223 245 173 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 193 228 218 213 198 180 212 210 211 213 223 220 243 202 0] [ 0 0 0 0 0 0 0 0 0 1 3 0 12 219 220 212 218 192 169 227 208 218 224 212 226 197 209 52] [ 0 0 0 0 0 0 0 0 0 0 6 0 99 244 222 220 218 203 198 221 215 213 222 220 245 119 167 56] [ 0 0 0 0 0 0 0 0 0 4 0 0 55 236 228 230 228 240 232 213 218 223 234 217 217 209 92 0] [ 0 0 1 4 6 7 2 0 0 0 0 0 237 226 217 223 222 219 222 221 216 223 229 215 218 255 77 0] [ 0 3 0 0 0 0 0 0 0 62 145 204 228 207 213 221 218 208 211 218 224 223 219 215 224 244 159 0] [ 0 0 0 0 18 44 82 107 189 228 220 222 217 226 200 205 211 230 224 234 176 188 250 248 233 238 215 0] [ 0 57 187 208 224 221 224 208 204 214 208 209 200 159 245 193 206 223 255 255 221 234 221 211 220 232 246 0] [ 3 202 228 224 221 211 211 214 205 205 205 220 240 80 150 255 229 221 188 154 191 210 204 209 222 228 225 0] [ 98 233 198 210 222 229 229 234 249 220 194 215 217 241 65 73 106 117 168 219 221 215 217 223 223 224 229 29] [ 75 204 212 204 193 205 211 225 216 185 197 206 198 213 240 195 227 245 239 223 218 212 209 222 220 221 230 67] [ 48 203 183 194 213 197 185 190 194 192 202 214 219 221 220 236 225 216 199 206 186 181 177 172 181 205 206 115] [ 0 122 219 193 179 171 183 196 204 210 213 207 211 210 200 196 194 191 195 191 198 
192 176 156 167 177 210 92] [ 0 0 74 189 212 191 175 172 175 181 185 188 189 188 193 198 204 209 210 210 211 188 188 194 192 216 170 0] [ 2 0 0 0 66 200 222 237 239 242 246 243 244 221 220 193 191 179 182 182 181 176 166 168 99 58 0 0] [ 0 0 0 0 0 0 0 40 61 44 72 41 35 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
# Render the first image in a small figure using the binary colormap.
plt.figure(figsize=(2, 2))
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()
# Show the first `numbers_to_display` images, each captioned with its class
# name, on the smallest square grid that can hold all of them.
numbers_to_display = 25
num_cells = math.ceil(math.sqrt(numbers_to_display))
plt.figure(figsize=(8, 8))
for i, (image, label) in enumerate(
    zip(x_train[:numbers_to_display], y_train[:numbers_to_display])
):
    # Subplot positions are 1-based.
    plt.subplot(num_cells, num_cells, i + 1)
    # Hide ticks and grid lines so that only the picture and caption remain.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(class_names[label])
plt.show()
# Add a trailing channel axis of size 1 and convert the pixels to float32.
x_train_reshaped = x_train.reshape(
x_train.shape[0],
x_train.shape[1],
x_train.shape[2],
1
).astype('float32')
print('x_train_reshaped.shape: ', x_train_reshaped.shape)
x_train_reshaped.shape: (70000, 28, 28, 1)
# Scale pixel values from [0, 255] to the [-1, 1] range, which is also the
# range of the generator's tanh output.
x_train_normalized = (x_train_reshaped - 127.5) / 127.5
# Print the first normalized image (channel 0).
print('Normalized data values:\n')
print(x_train_normalized[0,:,:,0])
Normalized data values: [[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -0.99215686 -1. -1. -0.8980392 -0.42745098 -1. -1. -0.99215686 -0.96862745 -1. -1. -1. -1. -0.99215686 -0.99215686 -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -0.9764706 -1. -0.7176471 0.06666667 -0.00392157 -0.5137255 -0.5764706 -1. -1. -1. -0.99215686 -0.9764706 -0.96862745 -1. -1. -0.9764706 ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -0.9529412 -1. -0.2 0.6 0.38039216 0.05098039 0.12941177 -0.03529412 -0.81960785 -1. -1. -1. -1. -0.90588236 -0.92156863 -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. 0.21568628 0.8509804 0.62352943 0.39607844 -0.16078432 0.22352941 0.2627451 -0.14509805 -0.49803922 -0.81960785 -0.39607844 0.01960784 -0.43529412 -0.88235295] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -0.99215686 -1. -0.45882353 0.62352943 0.7490196 0.70980394 0.69411767 0.69411767 0.2784314 -0.00392157 -0.05098039 -0.04313726 0.14509805 0.10588235 -0.30980393 0.34901962 -0.48235294] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -0.99215686 -0.99215686 -0.99215686 -1. 0.5686275 0.81960785 0.81960785 0.827451 0.79607844 0.7490196 0.7490196 0.6862745 0.67058825 0.28627452 -0.00392157 -0.03529412 0.5372549 0.79607844 -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. 0.43529412 0.7647059 0.69411767 0.7490196 0.7882353 0.84313726 0.78039217 0.75686276 0.7411765 0.75686276 0.73333335 0.7490196 0.92156863 0.35686275 -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. 0.5137255 0.7882353 0.70980394 0.67058825 0.5529412 0.4117647 0.6627451 0.64705884 0.654902 0.67058825 0.7490196 0.7254902 0.90588236 0.58431375 -1. 
] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -0.99215686 -0.9764706 -1. -0.90588236 0.7176471 0.7254902 0.6627451 0.70980394 0.5058824 0.3254902 0.78039217 0.6313726 0.70980394 0.75686276 0.6627451 0.77254903 0.54509807 0.6392157 -0.5921569 ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -0.9529412 -1. -0.22352941 0.9137255 0.7411765 0.7254902 0.70980394 0.5921569 0.5529412 0.73333335 0.6862745 0.67058825 0.7411765 0.7254902 0.92156863 -0.06666667 0.30980393 -0.56078434] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -0.96862745 -1. -1. -0.5686275 0.8509804 0.7882353 0.8039216 0.7882353 0.88235295 0.81960785 0.67058825 0.70980394 0.7490196 0.8352941 0.7019608 0.7019608 0.6392157 -0.2784314 -1. ] [-1. -1. -0.99215686 -0.96862745 -0.9529412 -0.94509804 -0.9843137 -1. -1. -1. -1. -1. 0.85882354 0.77254903 0.7019608 0.7490196 0.7411765 0.7176471 0.7411765 0.73333335 0.69411767 0.7490196 0.79607844 0.6862745 0.70980394 1. -0.39607844 -1. ] [-1. -0.9764706 -1. -1. -1. -1. -1. -1. -1. -0.5137255 0.13725491 0.6 0.7882353 0.62352943 0.67058825 0.73333335 0.70980394 0.6313726 0.654902 0.70980394 0.75686276 0.7490196 0.7176471 0.6862745 0.75686276 0.9137255 0.24705882 -1. ] [-1. -1. -1. -1. -0.85882354 -0.654902 -0.35686275 -0.16078432 0.48235294 0.7882353 0.7254902 0.7411765 0.7019608 0.77254903 0.5686275 0.60784316 0.654902 0.8039216 0.75686276 0.8352941 0.38039216 0.4745098 0.9607843 0.94509804 0.827451 0.8666667 0.6862745 -1. ] [-1. -0.5529412 0.46666667 0.6313726 0.75686276 0.73333335 0.75686276 0.6313726 0.6 0.6784314 0.6313726 0.6392157 0.5686275 0.24705882 0.92156863 0.5137255 0.6156863 0.7490196 1. 1. 0.73333335 0.8352941 0.73333335 0.654902 0.7254902 0.81960785 0.92941177 -1. ] [-0.9764706 0.58431375 0.7882353 0.75686276 0.73333335 0.654902 0.654902 0.6784314 0.60784316 0.60784316 0.60784316 0.7254902 0.88235295 -0.37254903 0.1764706 1. 0.79607844 0.73333335 0.4745098 0.20784314 0.49803922 0.64705884 0.6 0.6392157 0.7411765 0.7882353 0.7647059 -1. 
] [-0.23137255 0.827451 0.5529412 0.64705884 0.7411765 0.79607844 0.79607844 0.8352941 0.9529412 0.7254902 0.52156866 0.6862745 0.7019608 0.8901961 -0.49019608 -0.42745098 -0.16862746 -0.08235294 0.31764707 0.7176471 0.73333335 0.6862745 0.7019608 0.7490196 0.7490196 0.75686276 0.79607844 -0.77254903] [-0.4117647 0.6 0.6627451 0.6 0.5137255 0.60784316 0.654902 0.7647059 0.69411767 0.4509804 0.54509807 0.6156863 0.5529412 0.67058825 0.88235295 0.5294118 0.78039217 0.92156863 0.8745098 0.7490196 0.70980394 0.6627451 0.6392157 0.7411765 0.7254902 0.73333335 0.8039216 -0.4745098 ] [-0.62352943 0.5921569 0.43529412 0.52156866 0.67058825 0.54509807 0.4509804 0.49019608 0.52156866 0.5058824 0.58431375 0.6784314 0.7176471 0.73333335 0.7254902 0.8509804 0.7647059 0.69411767 0.56078434 0.6156863 0.45882353 0.41960785 0.3882353 0.34901962 0.41960785 0.60784316 0.6156863 -0.09803922] [-1. -0.04313726 0.7176471 0.5137255 0.40392157 0.34117648 0.43529412 0.5372549 0.6 0.64705884 0.67058825 0.62352943 0.654902 0.64705884 0.5686275 0.5372549 0.52156866 0.49803922 0.5294118 0.49803922 0.5529412 0.5058824 0.38039216 0.22352941 0.30980393 0.3882353 0.64705884 -0.2784314 ] [-1. -1. -0.41960785 0.48235294 0.6627451 0.49803922 0.37254903 0.34901962 0.37254903 0.41960785 0.4509804 0.4745098 0.48235294 0.4745098 0.5137255 0.5529412 0.6 0.6392157 0.64705884 0.64705884 0.654902 0.4745098 0.4745098 0.52156866 0.5058824 0.69411767 0.33333334 -1. ] [-0.9843137 -1. -1. -1. -0.48235294 0.5686275 0.7411765 0.85882354 0.8745098 0.8980392 0.92941177 0.90588236 0.9137255 0.73333335 0.7254902 0.5137255 0.49803922 0.40392157 0.42745098 0.42745098 0.41960785 0.38039216 0.3019608 0.31764707 -0.22352941 -0.54509807 -1. -1. ] [-1. -1. -1. -1. -1. -1. -1. -0.6862745 -0.52156866 -0.654902 -0.43529412 -0.6784314 -0.7254902 -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ] [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ] [-1. -1. 
-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. ]]
# The shuffle buffer is as large as the whole dataset.
SHUFFLE_BUFFER_SIZE = TOTAL_EXAMPLES_NUM
BATCH_SIZE = 1024
# Number of batches per epoch; rounded up because the last batch may be partial.
TRAINING_STEPS_PER_EPOCH = math.ceil(TOTAL_EXAMPLES_NUM / BATCH_SIZE)
print('BATCH_SIZE: ', BATCH_SIZE)
print('TRAINING_STEPS_PER_EPOCH: ', TRAINING_STEPS_PER_EPOCH)
BATCH_SIZE: 1024 TRAINING_STEPS_PER_EPOCH: 69
# Input pipeline: slice the normalized images into single examples,
# shuffle them, and group them into batches of BATCH_SIZE.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train_normalized) \
.shuffle(SHUFFLE_BUFFER_SIZE) \
.batch(BATCH_SIZE)
print(train_dataset)
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
def make_generator_model(latent_dim=100):
    """Build the DCGAN generator.

    The network projects a noise vector to a 7x7x256 volume and upsamples it
    with transposed convolutions to a single-channel 28x28 image whose values
    lie in [-1, 1] (tanh output).

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the
            value that used to be hard-coded, so existing calls are unaffected.

    Returns:
        A `tf.keras.Sequential` model mapping noise of shape
        (batch, latent_dim) to images of shape (batch, 28, 28, 1).
    """
    model = tf.keras.Sequential()

    # Step 1. Project the noise vector to 7*7*256 units.
    # The bias is dropped because the BatchNormalization that follows
    # already learns its own offset.
    model.add(tf.keras.layers.Dense(
        units=7*7*256,
        use_bias=False,
        input_shape=(latent_dim,)
    ))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # Step 2. Reinterpret the flat vector as a 7x7 feature map with 256 channels.
    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # None is a batch size.

    # Step 3. Stride 1 keeps the 7x7 resolution and halves the channels.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=128,
        kernel_size=(5, 5),
        strides=(1, 1),
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # Step 4. Stride 2 upsamples 7x7 -> 14x14.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=64,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # Step 5. Stride 2 upsamples 14x14 -> 28x28 with a single output channel;
    # tanh keeps the pixels in [-1, 1], like the normalized training images.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=1,
        kernel_size=(5, 5),
        strides=(2, 2),
        padding='same',
        use_bias=False,
        activation='tanh'
    ))
    assert model.output_shape == (None, 28, 28, 1)

    return model
# Build the generator and print its layer-by-layer summary.
generator_model = make_generator_model()
generator_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_2 (Dense) (None, 12544) 1254400 _________________________________________________________________ batch_normalization_3 (Batch (None, 12544) 50176 _________________________________________________________________ leaky_re_lu_5 (LeakyReLU) (None, 12544) 0 _________________________________________________________________ reshape_1 (Reshape) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_transpose_3 (Conv2DTr (None, 7, 7, 128) 819200 _________________________________________________________________ batch_normalization_4 (Batch (None, 7, 7, 128) 512 _________________________________________________________________ leaky_re_lu_6 (LeakyReLU) (None, 7, 7, 128) 0 _________________________________________________________________ conv2d_transpose_4 (Conv2DTr (None, 14, 14, 64) 204800 _________________________________________________________________ batch_normalization_5 (Batch (None, 14, 14, 64) 256 _________________________________________________________________ leaky_re_lu_7 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ conv2d_transpose_5 (Conv2DTr (None, 28, 28, 1) 1600 ================================================================= Total params: 2,330,944 Trainable params: 2,305,472 Non-trainable params: 25,472 _________________________________________________________________
# Save a diagram of the generator (with shapes and layer names)
# to generator_model.png.
tf.keras.utils.plot_model(
generator_model,
show_shapes=True,
show_layer_names=True,
to_file='generator_model.png'
)
# Sample one 100-dimensional noise vector from a normal distribution.
noise = tf.random.normal(shape=[1, 100])
print(noise.numpy())
[[-0.40048575 -0.05899175 1.4562819 -1.0014365 -0.49212027 -0.02965875 -0.359166 -0.13207589 -0.05057759 -0.7165569 1.348586 -0.5673278 1.2634926 0.69921404 0.1937053 0.6227805 -0.8695679 1.5103983 -0.87547743 -0.9266758 -0.94829696 -0.05095936 -0.33741903 1.7615291 -0.5133022 -1.9463009 -0.47871232 -0.3488199 0.30957273 0.7816733 -0.05789489 -0.65704983 1.328389 -0.17778188 -0.09234082 1.3932896 -0.4493648 1.0459919 -0.27539548 -0.5148557 0.6957658 0.2747286 1.1346655 -0.58931726 1.2654362 -1.3202895 0.05925234 0.57879525 0.10334445 0.12310242 0.9459113 0.46510884 -1.7625321 1.1537969 0.95482457 0.07740594 1.101719 -0.94574606 0.90892 1.0484036 -1.3593727 0.44116554 -1.3168752 -0.6017726 0.22471838 0.524481 -0.14121588 0.20263085 -0.29945704 0.7611715 0.42068037 0.25163245 1.2361112 -0.49910775 0.5315448 -0.29282215 0.44396016 0.18037954 -0.66255087 -0.01424331 0.25285572 0.54898584 0.01249223 -1.8984798 0.34870186 -0.38278985 -1.2752656 -0.22798996 1.2903773 0.5359575 -0.8025157 1.1422137 1.4703174 -0.08264713 0.27506617 1.0309777 -0.48443326 0.76453096 0.57719016 0.5316679 ]]
# Run the (still untrained) generator on the noise vector with training=False.
generated_image = generator_model(noise, training=False)
print('generated_image.shape: ', generated_image.shape)
generated_image.shape: (1, 28, 28, 1)
# Print the pixel values of the generated image (first example, channel 0).
print(generated_image[0, :, :, 0].numpy())
[[ 3.62146273e-03 -3.37373791e-03 -2.23259837e-03 4.66620969e-03 1.15258023e-04 5.01324749e-03 -5.11144102e-03 -4.98591131e-03 3.25806858e-03 3.24650109e-03 -3.68705858e-03 -7.00298278e-03 7.17507210e-03 3.48935777e-04 -1.13091886e-03 -3.62540688e-03 4.15722467e-03 -5.91829943e-04 4.71155299e-03 -3.61183868e-03 3.84368817e-03 5.68057830e-03 5.30142151e-03 -1.17566239e-03 -1.84537598e-03 -2.84438924e-04 -1.76289608e-03 6.83842576e-04] [-2.32459349e-03 -1.27953419e-03 -3.37371975e-03 -1.02897873e-03 -1.19031165e-02 -3.47901043e-03 1.06568029e-02 1.61387771e-02 -1.69485938e-02 1.10559864e-03 -1.05018653e-02 5.57437493e-03 2.92407884e-03 -4.07211157e-03 5.88989642e-04 2.27185269e-03 1.62952056e-03 9.98135749e-03 6.50358619e-03 5.50584961e-03 -3.35985026e-03 5.92968427e-03 8.01353902e-03 1.04573760e-02 -5.20926062e-03 -3.15492740e-03 -9.45294369e-03 -2.02907156e-03] [-1.08655437e-03 -9.76242532e-04 1.22154327e-02 2.89549213e-03 -2.40238005e-04 -5.81492018e-03 1.95882097e-03 -9.22838750e-04 1.82223914e-03 -7.69637479e-03 -5.21446159e-03 -1.09772468e-02 -5.95692173e-03 -2.33915867e-03 -9.31353495e-03 1.22239972e-02 -6.31380128e-03 1.30199986e-02 1.08500058e-02 -1.09621149e-03 4.78329835e-03 -1.27854310e-02 -4.04617889e-03 1.11324312e-02 -2.68187723e-03 -8.11451860e-03 1.80690826e-04 -3.23777436e-03] [-4.34456859e-03 1.08021591e-03 9.07736551e-03 6.71621133e-03 1.94656372e-03 -1.95164117e-03 -1.32440859e-02 1.99190094e-04 2.16311179e-02 -1.28664607e-02 -7.09737279e-03 -5.89591684e-03 5.80052426e-03 -4.83242143e-03 -1.52255045e-02 1.20152002e-02 1.47292307e-02 9.15968511e-03 1.51714450e-02 1.48223871e-02 1.03338063e-02 -1.48535576e-02 -1.52136211e-03 9.71621554e-03 4.76413919e-03 -6.04430959e-03 5.14708646e-03 -2.87405419e-04] [-3.02930758e-03 -1.19427703e-02 -2.88041658e-03 2.63637368e-04 -2.47954595e-04 -2.68383650e-03 -8.25305842e-03 4.87447809e-03 3.87723232e-03 1.71127636e-02 1.76618434e-02 -3.68945720e-03 -1.29244069e-03 4.67234664e-03 -1.35340896e-02 1.63192730e-02 
1.62246879e-02 -3.82160139e-03 -1.26241334e-02 -6.73409458e-03 -2.56166025e-03 -1.48072299e-02 9.48595873e-04 -5.34741674e-03 5.62866312e-03 7.48308841e-03 4.01935400e-03 -3.19045852e-03] [-3.48949130e-03 -9.35084280e-03 -8.24388862e-03 -1.24650507e-03 2.93476880e-03 -2.71880403e-02 -5.54847671e-03 2.81886011e-02 -9.42054112e-03 -6.17643725e-03 -3.36366938e-03 -3.17791626e-02 1.46224005e-02 -8.51377100e-03 -5.54802548e-03 4.99748439e-02 -1.47969567e-03 1.36440508e-02 1.15090255e-02 6.00805460e-03 -2.44427957e-02 -2.54061539e-03 -1.94960111e-03 1.77331921e-02 -1.45183373e-02 -9.58510581e-03 -2.48887925e-03 6.41941139e-03] [ 4.38543037e-03 -2.23167450e-03 1.44272868e-03 -3.01642832e-03 -1.22719235e-03 -7.65156560e-03 -1.68482959e-03 -1.07503235e-02 -1.77560211e-03 -1.05693024e-02 8.33574310e-03 -2.35626772e-02 -1.26337176e-02 -1.56643074e-02 -8.19774810e-04 -5.32189431e-03 -6.26925472e-03 -9.05521493e-03 -1.19840242e-02 4.97127825e-04 -1.48960864e-02 -2.39627250e-02 -5.92306256e-03 1.59224458e-02 -1.52227646e-02 -1.00052310e-02 6.09163009e-03 -2.11464837e-02] [ 2.21529067e-03 1.42816352e-02 5.81281027e-03 -3.23327188e-03 5.38593065e-03 -1.07889320e-03 -2.99738417e-03 2.65213521e-03 -6.27323193e-03 2.66653467e-02 2.00978648e-02 1.15421542e-03 1.33487210e-02 3.74643435e-03 3.68786044e-03 2.31483839e-02 6.19633729e-03 -1.05314497e-02 -2.16631126e-02 3.11499853e-02 -3.53607279e-03 1.51188718e-02 -1.13426456e-02 4.78743464e-02 2.75963675e-02 -4.74607712e-03 2.98028695e-03 -1.47000421e-02] [-9.75484552e-04 -1.87517386e-02 6.39571156e-03 -1.09980404e-02 -6.36178348e-03 5.97111508e-03 -8.37849453e-04 -5.37107652e-03 -1.53775848e-02 -1.13707948e-02 1.30695766e-02 -1.22213075e-02 7.32243853e-03 4.31752158e-03 3.81027418e-03 -1.11757852e-02 1.48883620e-02 5.91330696e-03 6.45949878e-03 -9.61624738e-03 1.23961624e-02 3.12477164e-03 1.78550836e-02 -1.53378481e-02 1.68247763e-02 2.35145888e-03 9.66674462e-03 -1.26689449e-02] [-4.53034788e-03 -5.13586868e-03 -1.09016187e-02 
2.11124518e-03 1.03256200e-02 1.92142017e-02 1.01808663e-02 7.77163729e-03 -1.13144163e-02 -2.06057518e-03 1.64466747e-03 -1.26656434e-02 -1.25816762e-02 -2.35923030e-03 -1.24072339e-02 2.99524952e-04 1.59975723e-03 -4.04913984e-02 -1.04828384e-02 2.14716420e-02 -1.75706518e-03 6.97366986e-03 -6.76607294e-03 5.74283442e-03 -9.74454393e-04 3.70335430e-02 3.50344516e-02 -2.19072085e-02] [-6.33765059e-03 -7.24817067e-03 6.53755618e-03 1.95069844e-03 2.49424344e-03 9.76337679e-03 2.74362997e-03 -4.01810836e-03 -6.89734798e-03 7.55376508e-03 4.68930113e-04 -1.00104986e-02 1.61441267e-02 2.13299580e-02 8.37038644e-03 -1.22260777e-02 -5.00570657e-03 -1.10691898e-02 -1.48134297e-02 3.67610437e-06 -7.90168159e-03 -2.45886985e-02 2.16707699e-02 1.30667910e-02 -2.27732006e-02 -2.44000666e-02 4.69451724e-03 -4.65407036e-03] [-3.97500495e-04 5.59741352e-03 -1.26262533e-03 1.38056604e-02 7.69853080e-03 1.02469148e-02 2.11865967e-03 2.34880578e-02 -1.36010703e-02 2.41397936e-02 1.47145968e-02 2.73363153e-03 2.54591051e-02 1.31826168e-02 5.61611541e-03 2.16647629e-02 1.68988332e-02 -5.91650896e-04 -1.88403949e-02 3.47850434e-02 -1.03903115e-02 1.19025661e-02 2.24366654e-02 1.25314258e-02 -9.98511538e-03 -4.03224770e-03 -2.05838848e-02 -3.67625132e-02] [-6.81647705e-03 -6.81502337e-04 -1.02212690e-02 -1.06686605e-02 -1.42854196e-03 6.24759961e-03 -1.24904374e-02 -4.81363852e-03 1.33938845e-02 2.21875357e-03 -1.23913633e-02 3.41113540e-03 1.73668731e-02 -5.56474971e-03 -1.62793212e-02 -2.01927647e-02 -1.38299596e-02 -2.20136046e-02 -1.97079289e-03 -2.30370704e-02 1.23280697e-02 1.55290756e-02 -9.70291439e-03 -4.64379601e-03 6.73847552e-03 2.43831128e-02 3.51902819e-03 -2.45220028e-02] [ 6.70772512e-03 -5.60176698e-03 -1.80406333e-03 -1.35467006e-02 -1.31471567e-02 -1.45529835e-02 9.33896867e-04 -2.44586216e-03 2.89507862e-02 1.62595250e-02 8.47978052e-03 -1.79202054e-02 4.06631036e-03 -4.48405147e-02 -4.93633561e-03 2.67268904e-03 7.71781197e-04 -2.08309945e-02 -2.98636500e-02 
-2.63190269e-02 -3.18020917e-02 -8.53830110e-03 -1.12102060e-02 -3.05246729e-02 -2.17090603e-02 -2.07713293e-03 2.57720239e-03 -7.04809674e-04] [-1.30386988e-03 -1.52483920e-03 2.78399838e-03 -6.23396947e-04 -7.02505186e-03 2.66378094e-02 -9.57542099e-03 3.94407572e-04 9.45235137e-03 -3.91205922e-02 -2.20355834e-03 3.01534543e-03 -7.43860006e-03 -2.32779118e-03 -7.99226016e-03 2.78095668e-03 5.68673387e-03 4.55740327e-03 -2.12263670e-02 -1.51697071e-02 -1.48749324e-02 -1.68445744e-02 -1.43898488e-03 -1.12640904e-02 -1.93291102e-02 -2.06826311e-02 -4.39894153e-03 -1.58750862e-02] [ 8.92661221e-04 2.87332553e-02 5.50161488e-03 2.76262425e-02 9.34928306e-04 1.70804355e-02 9.89066157e-03 2.29741875e-02 -2.80773058e-03 2.93803997e-02 -4.72826976e-03 1.22490549e-03 1.87103860e-02 3.96030135e-02 -1.48638757e-02 -6.88467501e-03 4.44180798e-03 9.90787428e-03 1.96425766e-02 2.99594030e-02 1.13659725e-02 -9.46089742e-04 1.61536806e-03 2.01995093e-02 2.76264064e-02 1.35942772e-02 -1.88921001e-02 2.23646569e-03] [ 5.96996909e-03 3.21599364e-04 -2.40466148e-02 2.00503990e-02 -1.41588096e-02 -1.20846624e-03 -3.94202815e-03 -4.52067005e-03 1.48022855e-02 -5.81123924e-04 -7.56723946e-03 -2.56106053e-02 -2.38823914e-03 1.01325503e-02 1.22936796e-02 4.06868234e-02 5.15955733e-03 4.25607292e-03 -1.10634714e-02 -2.30548475e-02 -3.66645493e-03 3.35261784e-03 -7.60676712e-03 -8.83055013e-03 -5.17971860e-03 -1.00199468e-02 7.56571535e-03 -3.23796133e-03] [-1.83949678e-03 -8.92570149e-03 -1.32302679e-02 -9.34507325e-03 8.84845294e-03 8.96911044e-03 6.63216552e-03 1.58976465e-02 -6.10901276e-03 9.65222716e-04 -5.26000001e-03 2.36379020e-02 1.09209521e-02 1.68258753e-02 2.02850197e-02 -2.11746078e-02 3.50818895e-02 4.76629473e-03 3.85219534e-03 -7.08658341e-03 -1.38495686e-02 1.56487878e-02 -7.85811152e-03 2.36311066e-03 -2.31542159e-02 5.98499179e-03 -2.21874677e-02 7.89724570e-03] [ 2.49292911e-03 -1.22126425e-03 2.48686783e-02 5.92503930e-04 -1.04724327e-02 -1.32915620e-02 1.57627966e-02 
-2.26045437e-02 -1.77267920e-02 -4.18708995e-02 -1.75703075e-02 -1.85578279e-02 -1.01273982e-02 -1.36248032e-02 9.71704908e-03 -2.17672940e-02 -2.45038588e-02 -8.22731759e-03 -3.20212208e-02 1.86296720e-02 -2.20870040e-03 -6.29124790e-03 -3.48552503e-03 1.67474449e-02 4.64327773e-03 6.60257973e-03 5.24758128e-03 -1.41102597e-02] [-3.59580526e-03 2.54093725e-02 -4.23988618e-04 1.20853363e-02 1.81288517e-03 1.16510382e-02 1.78205734e-03 -3.17027746e-03 -5.59030659e-03 8.83010961e-03 -4.48579621e-03 5.38882939e-03 2.30905674e-02 1.38789257e-02 2.89347838e-03 5.11234975e-04 8.46187305e-03 -6.83443667e-03 2.63330564e-02 5.61695285e-02 2.83867344e-02 2.13272609e-02 1.86818764e-02 2.40119528e-02 4.12550941e-02 1.71269625e-02 -1.16232736e-02 -1.12093156e-02] [-5.26016229e-04 -8.02509487e-03 -1.78317726e-03 1.26240542e-02 8.84521101e-03 -7.15629384e-03 4.48989728e-03 1.50084337e-02 3.44703579e-03 2.97801523e-03 1.95164531e-02 2.66057410e-04 8.63145944e-03 -1.68882143e-02 7.61054363e-03 1.93403792e-02 -9.89504531e-03 -2.29283813e-02 -4.38777643e-04 1.24670612e-02 1.84697332e-03 1.63523685e-02 -2.33178455e-02 -4.12332732e-03 -2.49453937e-03 4.16569673e-02 -6.18983945e-03 3.15394904e-03] [ 2.73987162e-03 -5.72156301e-03 4.67017666e-03 -3.71920830e-03 -1.73664801e-02 -8.21895618e-03 -9.91616305e-03 6.70508016e-03 -6.69953274e-03 -1.69223372e-03 1.84470660e-03 3.59362410e-03 1.25817265e-02 -2.89662275e-02 2.10328940e-02 -4.60524764e-03 1.11880424e-02 -2.16246583e-03 9.07421485e-03 4.56262659e-03 1.82753647e-04 -3.30339260e-02 8.84283800e-03 -4.15881164e-03 -1.63521711e-02 -3.97552326e-02 1.58415325e-02 3.19608976e-03] [ 2.95576709e-03 -1.34218726e-02 2.30270363e-02 -9.30730719e-03 -2.18246076e-02 -1.60645489e-02 5.29193785e-03 7.57779853e-05 -2.87380698e-03 1.59277150e-03 9.97675862e-03 -1.49989566e-02 6.64960174e-03 -3.60953957e-02 1.10111600e-02 8.14057013e-04 -7.30547169e-03 -2.98824403e-02 -1.86987207e-04 -1.52166411e-02 -1.55169964e-02 -3.13213188e-03 -1.91632484e-03 
-7.88594875e-03 6.13504043e-03 -2.98424568e-02 -8.94748140e-03 -2.02390347e-02] [ 1.76868285e-03 1.23285577e-02 -3.77325248e-03 1.73113104e-02 -3.51235792e-02 2.40283497e-02 -3.46030295e-03 3.47760394e-02 1.39564648e-03 2.72532441e-02 -1.16243036e-02 -5.48969116e-03 -1.43558974e-03 3.63925993e-02 2.29507256e-02 2.23198514e-02 8.72469880e-03 4.75151688e-02 1.37448986e-03 4.37433040e-03 2.42167786e-02 -1.66251697e-02 -1.14183633e-04 6.21859496e-03 3.36538441e-02 -1.27480784e-02 -1.96997309e-03 5.11262275e-04] [ 3.15447431e-03 -9.60949156e-03 2.53796461e-03 1.22282030e-02 -3.31321498e-03 1.04366830e-02 9.54547897e-03 1.72090158e-03 -5.68748277e-04 8.43942817e-03 5.74815262e-04 3.61018814e-02 1.43055217e-02 1.49328273e-03 -1.61119234e-02 9.13903210e-03 3.24818306e-02 -1.55513333e-02 4.50196257e-03 -7.38349790e-03 3.01722754e-02 -3.80781991e-03 1.55264568e-02 -4.45965398e-03 -5.50039054e-04 -4.53713071e-03 -4.82875621e-03 -1.75618660e-02] [-1.39705255e-03 -9.56590381e-03 -1.24372635e-02 -7.42626889e-03 -1.31014327e-03 1.31019633e-02 -2.14343797e-02 -1.36958398e-02 -2.04797834e-02 -2.22509820e-02 -1.76416654e-02 2.37367339e-02 1.06160017e-02 1.56296641e-02 1.81274815e-03 -1.02035957e-03 -2.93547707e-03 -1.68886241e-02 -1.78066101e-02 2.42912993e-02 1.25687057e-02 -5.75363776e-03 -3.67524044e-04 1.31052989e-03 2.40805000e-03 -2.77734222e-03 8.07198603e-03 1.07840775e-02] [-2.46818783e-03 5.46638342e-03 -2.33211252e-03 3.60420416e-03 -4.41368949e-03 1.26605816e-02 1.38953701e-03 -2.76276818e-03 -4.74420562e-03 -3.86390765e-03 -1.23261986e-02 -1.77381169e-02 -1.25269650e-03 -1.33819750e-03 1.38756512e-02 -1.65214879e-03 -1.91068149e-03 -1.06625482e-02 5.72940707e-03 2.26340070e-03 8.95494036e-03 -4.09065047e-03 1.84036735e-02 -1.92423984e-02 -1.22753307e-02 -1.53727038e-02 -3.91660351e-03 -3.69235151e-03] [ 1.52455550e-03 4.71905223e-04 2.51185521e-03 3.77575343e-04 3.94193502e-03 -2.46343319e-03 1.21627003e-03 1.43145500e-02 2.26958841e-02 -2.91464501e-03 -2.12998893e-02 
-1.22428909e-02 -4.66084020e-04 9.76727158e-03 -1.22103025e-03 2.02955268e-02 1.04613770e-02 8.11114442e-03 -4.82826121e-03 2.12394483e-02 2.55176052e-02 -1.36667909e-02 -1.34665845e-02 6.25952985e-03 7.91868195e-03 -2.37570284e-03 -1.96220633e-02 2.20343564e-03]]
# Display the generated image in a small figure using the binary colormap.
plt.figure(figsize=(2, 2))
plt.imshow(generated_image[0, :, :, 0], cmap=plt.cm.binary)
<matplotlib.image.AxesImage at 0x1315d7450>
The discriminator will be trained to output positive values for real images and negative values for fake images.
def make_discriminator_model():
    """Build the DCGAN discriminator.

    Two strided convolutions downsample a 28x28x1 image, and a single dense
    unit turns the flattened features into one raw score (a logit, since no
    activation is applied).

    Returns:
        A `tf.keras.Sequential` model mapping (batch, 28, 28, 1) images to
        (batch, 1) scores.
    """
    layers = tf.keras.layers

    return tf.keras.Sequential([
        # Step 1. 28x28x1 -> 14x14x64.
        layers.Conv2D(
            filters=64,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same',
            input_shape=[28, 28, 1]
        ),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        # Step 2. 14x14x64 -> 7x7x128.
        layers.Conv2D(
            filters=128,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same'
        ),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        # Step 3. Flatten the feature map into a vector.
        layers.Flatten(),

        # Real vs Fake: one unbounded score per image.
        layers.Dense(1),
    ])
# Build the discriminator and print its layer-by-layer summary.
discriminator_model = make_discriminator_model()
discriminator_model.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_4 (Conv2D) (None, 14, 14, 64) 1664 _________________________________________________________________ leaky_re_lu_10 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ dropout_4 (Dropout) (None, 14, 14, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 7, 7, 128) 204928 _________________________________________________________________ leaky_re_lu_11 (LeakyReLU) (None, 7, 7, 128) 0 _________________________________________________________________ dropout_5 (Dropout) (None, 7, 7, 128) 0 _________________________________________________________________ flatten_2 (Flatten) (None, 6272) 0 _________________________________________________________________ dense_4 (Dense) (None, 1) 6273 ================================================================= Total params: 212,865 Trainable params: 212,865 Non-trainable params: 0 _________________________________________________________________
# Save a diagram of the discriminator (with shapes and layer names)
# to discriminator_model.png.
tf.keras.utils.plot_model(
discriminator_model,
show_shapes=True,
show_layer_names=True,
to_file='discriminator_model.png'
)
# Score the generated image with the (still untrained) discriminator.
# The result is a raw score: the final Dense layer has no activation.
# NOTE(review): `dicision` is a misspelling of `decision`; the name is kept
# because code outside this excerpt may reference it.
dicision = discriminator_model(generated_image)
print(dicision)
tf.Tensor([[0.00100495]], shape=(1, 1), dtype=float32)
This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s.
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's total binary cross-entropy loss.

    Args:
        real_output: Discriminator logits for real images (target label 1).
        fake_output: Discriminator logits for generated images (target label 0).

    Returns:
        A scalar tensor: loss on the real batch plus loss on the fake batch.
    """
    # The discriminator emits logits, hence from_logits=True.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    loss_on_fakes = bce(tf.zeros_like(fake_output), fake_output)
    loss_on_reals = bce(tf.ones_like(real_output), real_output)
    return loss_on_reals + loss_on_fakes
# Test discriminator loss function.
# Each pair holds the discriminator logit for a REAL image and for a FAKE one.
output_combinations = [
    # REAL  # FAKE
    ([-1.], [1.]),
    ([1.], [-1.]),
    ([1.], [0.]),
    ([10.], [-1.]),
]

for real_output, fake_output in output_combinations:
    loss = discriminator_loss(real_output, fake_output).numpy()
    print('Discriminator loss for:', real_output, fake_output)
    print(' REAL output:', real_output)
    print(' FAKE output:', fake_output)
    print(' loss: ', loss)
    print()
Discriminator loss for: [-1.0] [1.0] REAL output: [-1.0] FAKE output: [1.0] loss: 2.6265235 Discriminator loss for: [1.0] [-1.0] REAL output: [1.0] FAKE output: [-1.0] loss: 0.6265234 Discriminator loss for: [1.0] [0.0] REAL output: [1.0] FAKE output: [0.0] loss: 1.0064089 Discriminator loss for: [10.0] [-1.0] REAL output: [10.0] FAKE output: [-1.0] loss: 0.31330708
The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminator's decisions on the generated images to an array of 1s.
def generator_loss(fake_output):
    """Return the generator's binary cross-entropy loss.

    The discriminator's outputs for generated images (interpreted as logits)
    are scored against a target of all ones, so the loss shrinks as the
    discriminator rates the fakes as real.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    all_real_targets = tf.ones_like(fake_output)
    return bce(all_real_targets, fake_output)
# Test generator loss function.
# A large positive logit means the fake was rated as real, so the loss is near 0.
print('Generator loss for >1: ', generator_loss([5.]).numpy())
# A logit of 0 corresponds to a probability of 0.5, giving a loss of ln(2).
print('Generator loss for =0: ', generator_loss([0.]).numpy())
Generator loss for >1: 0.0067153485 Generator loss for =0: 0.6931472
# Each network gets its own Adam optimizer because they are updated separately.
generator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.0001
)
discriminator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.0001
)
# Checkpoints are written under `checkpoint_dir` with the `ckpt` file prefix.
checkpoint_dir = './tmp/ckpt'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# Track both models together with both optimizers in a single checkpoint.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator_model=generator_model,
    discriminator_model=discriminator_model
)
# Number of epochs trained per session.
EPOCHS = 100
# Length of the random noise vector fed to the generator.
noise_dim = 100
# Number of preview images rendered after each epoch.
num_examples_to_generate = 16
# We will reuse this seed over time (so it's easier
# to visualize progress in the animated GIF).
input_noise_seed = tf.random.normal([num_examples_to_generate, noise_dim])
The training loop begins with the generator receiving a random seed as input. That seed is used to produce an image. The discriminator is then used to classify real images (drawn from the training set) and fake images (produced by the generator). The loss is calculated for each of these models, and the gradients are used to update the generator and discriminator.
# This `tf.function` annotation causes the function to be "compiled".
# It is left disabled here; presumably because the `.numpy()` calls below
# need eagerly evaluated tensors (TODO confirm).
# @tf.function
def train_step(real_images):
    """Run one adversarial update on a batch and report both losses.

    A batch of `BATCH_SIZE` noise vectors is turned into fake images, the
    discriminator scores the real and the fake images, and each network is
    updated by its own optimizer from its own loss.

    Args:
        real_images: Batch of real training images for the discriminator.

    Returns:
        A dict of the form
        `{'discriminator': {'loss': ...}, 'generator': {'loss': ...}}`
        holding the losses of this step converted with `.numpy()`.
    """
    latent_batch = tf.random.normal([BATCH_SIZE, noise_dim])

    # Two tapes, because each network is differentiated against its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator_model(latent_batch, training=True)

        real_output = discriminator_model(real_images, training=True)
        fake_output = discriminator_model(fake_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    generator_vars = generator_model.trainable_variables
    discriminator_vars = discriminator_model.trainable_variables

    generator_grads = gen_tape.gradient(gen_loss, generator_vars)
    discriminator_grads = disc_tape.gradient(disc_loss, discriminator_vars)

    generator_optimizer.apply_gradients(zip(generator_grads, generator_vars))
    discriminator_optimizer.apply_gradients(
        zip(discriminator_grads, discriminator_vars)
    )

    return {
        'discriminator': {'loss': disc_loss.numpy()},
        'generator': {'loss': gen_loss.numpy()},
    }
def train(dataset, epochs, start_epoch=0):
    """Train the GAN and return the recorded loss history.

    For every epoch in `range(epochs)[start_epoch:]` each batch of `dataset`
    is passed to `train_step`. After the epoch a preview image is rendered
    from `input_noise_seed`, and every 10th epoch a checkpoint is saved.

    Args:
        dataset: Iterable yielding batches of real images.
        epochs: Exclusive upper bound of the (zero-based) epoch index.
        start_epoch: Zero-based epoch index to start from, so that a
            previous training session can be continued.

    Returns:
        A dict of the form
        `{'discriminator': {'loss': [...]}, 'generator': {'loss': [...]}}`
        with one entry per trained epoch: the loss of that epoch's last batch.

    Raises:
        ValueError: If `dataset` yields no batches during an epoch.
    """
    print('Start training...')

    training_history = {
        'discriminator': {'loss': []},
        'generator': {'loss': []},
    }

    for epoch in range(epochs)[start_epoch:]:
        print('Start epoch #{} ({} steps)...'.format(epoch + 1, TRAINING_STEPS_PER_EPOCH))
        start = time.time()

        # Reset per epoch so that an empty dataset can neither leave these
        # names unbound nor silently reuse the previous epoch's values.
        discriminator_step_loss = None
        generator_step_loss = None

        for image_batch in dataset:
            training_step_history = train_step(image_batch)
            discriminator_step_loss = training_step_history['discriminator']['loss']
            generator_step_loss = training_step_history['generator']['loss']

        if discriminator_step_loss is None or generator_step_loss is None:
            raise ValueError(
                'The dataset yielded no batches for epoch #{}'.format(epoch + 1)
            )

        # NOTE(review): the history is recorded once per epoch, which matches
        # the 'Epoch' x-axis used when the history is plotted.
        training_history['discriminator']['loss'].append(discriminator_step_loss)
        training_history['generator']['loss'].append(generator_step_loss)

        # Produce images for the GIF as we go.
        display.clear_output(wait=True)
        generate_and_save_images(
            generator_model,
            epoch + 1,
            input_noise_seed
        )

        # Save the model every 10 epochs.
        if (epoch + 1) % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch #{} is {:.2f}s'.format(epoch + 1, time.time() - start))
        print('Discriminator loss: {:.4f}'.format(discriminator_step_loss))
        print('Generator loss: {:.4f}'.format(generator_step_loss))

    return training_history
def show_progress(current_step, total_steps):
    """Print a one-line text progress bar of constant width.

    The bar always consists of 50 cells (one cell per 2%). The number of
    filled cells is the floored percentage divided by two; the remaining
    cells are drawn as empty, so the total width no longer shrinks to 49
    for odd percentages.

    Args:
        current_step: Number of steps done so far.
        total_steps: Total number of steps; must be non-zero, otherwise the
            percentage computation raises `ZeroDivisionError`.
    """
    length_divider = 2
    bar_length = 100 // length_divider
    progress = math.floor(current_step * 100 / total_steps)
    # Clamp so that overshooting or negative steps cannot change the bar width.
    done_cells = min(max(progress // length_divider, 0), bar_length)
    left_cells = bar_length - done_cells
    done_dots = '◼︎' * done_cells
    left_dots = '・' * left_cells
    print(f'{current_step}/{total_steps}: {done_dots}{left_dots}')
# Test progress function.
# 15 of 68 steps floors to 22%, so 11 cells of the bar should be filled.
show_progress(15, 68)
15/68: ◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎◼︎・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・
# Directory the preview images are written to (note the trailing slash:
# file names are appended to this string directly).
IMAGES_PREVIEW_PATH = 'tmp/imgs/'
# `exist_ok=True` keeps this idempotent and avoids the check-then-create race.
os.makedirs(IMAGES_PREVIEW_PATH, exist_ok=True)
def generate_and_save_images(model, epoch, test_input, save=True):
    """Render the model's output for `test_input` as a square image grid.

    Args:
        model: Generator model, called as `model(test_input, training=False)`.
        epoch: Epoch number used in the saved file name.
        test_input: Batch of noise vectors to generate images from.
        save: When true, the grid is also written to
            `IMAGES_PREVIEW_PATH` as `image_at_epoch_<epoch>.png`.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    # Size the grid from the actual number of predictions, so that every
    # image gets a cell regardless of the batch size passed in.
    examples_num = predictions.shape[0]
    fig_dimension = max(1, math.ceil(math.sqrt(examples_num)))

    # Only one figure is created; an additional empty figure would show up
    # as "<Figure ... with 0 Axes>" and accumulate over the epochs.
    fig = plt.figure(figsize=(fig_dimension, fig_dimension))

    for i in range(examples_num):
        plt.subplot(fig_dimension, fig_dimension, i + 1)
        # Presumably the generator outputs values in [-1, 1]; this rescales
        # them to [0, 255] for display (TODO confirm against the generator).
        plt.imshow(
            predictions[i, :, :, 0] * 127.5 + 127.5,
            cmap=plt.cm.binary
        )
        plt.axis('off')

    if save:
        plt.savefig('{}image_at_epoch_{:04d}.png'.format(IMAGES_PREVIEW_PATH, epoch))

    plt.show()
    # Release the figure so that repeated calls do not pile up open figures.
    plt.close(fig)
# Create the accumulated history only once, so that re-running this cell
# for a further training session keeps the losses collected so far.
if 'training_history' not in locals():
    training_history = {
        'discriminator': {'loss': []},
        'generator': {'loss': []},
    }

# Every session covers the next `EPOCHS` epochs; the session number is
# edited by hand before the cell is re-run.
training_session_num = 1
start_epoch = training_session_num * EPOCHS
epochs_num = start_epoch + EPOCHS

training_history_current = train(
    train_dataset,
    start_epoch=start_epoch,
    epochs=epochs_num,
)
<Figure size 576x576 with 0 Axes>
Time for epoch #2700 is 10.57s Discriminator loss: 1.2439 Generator loss: 0.9003
# Append the losses of the session that just finished to the overall history.
training_history['generator']['loss'].extend(
    training_history_current['generator']['loss']
)
training_history['discriminator']['loss'].extend(
    training_history_current['discriminator']['loss']
)
def render_training_history(training_history):
    """Plot the generator and discriminator loss curves side by side.

    Args:
        training_history: Dict of the form
            `{'generator': {'loss': [...]}, 'discriminator': {'loss': [...]}}`.
    """
    # (subplot position, title, legend label, values) for each panel.
    panels = (
        (1, 'Generator Loss', 'Generator loss',
         training_history['generator']['loss']),
        (2, 'Discriminator Loss', 'Discriminator loss',
         training_history['discriminator']['loss']),
    )

    plt.figure(figsize=(14, 4))

    for position, title, legend_label, losses in panels:
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.plot(losses, label=legend_label)
        plt.legend()
        plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.show()
# Plot the losses accumulated over all training sessions so far.
render_training_history(training_history)
# Load the latest checkpoint found in `checkpoint_dir` into the tracked
# models and optimizers.
# NOTE(review): this replaces the in-memory weights right before the export
# below; confirm that exporting the checkpointed state is intended.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# Export both models as HDF5 files.
generator_model.save('generator_model.h5', save_format='h5')
discriminator_model.save('discriminator_model.h5', save_format='h5')
# Restore models from files if needed.
# generator_model.load_weights('./generator_model.h5')
# discriminator_model.load_weights('./discriminator_model.h5')
def zip_image_previews(images_previews_path=None,
                       images_previews_zip_name='images_previews.zip'):
    """Pack every file below the previews directory into one zip archive.

    Files are stored under their base name only, i.e. the directory
    structure is flattened inside the archive.

    Args:
        images_previews_path: Directory to walk. Defaults to
            `IMAGES_PREVIEW_PATH` when omitted.
        images_previews_zip_name: Path of the archive to create; an existing
            file is overwritten.
    """
    if images_previews_path is None:
        # Resolved at call time so the module-level constant stays the default.
        images_previews_path = IMAGES_PREVIEW_PATH

    zipped_files_num = 0
    with zipfile.ZipFile(images_previews_zip_name, mode='w') as zip_obj:
        for folder_name, _subfolders, filenames in os.walk(images_previews_path):
            for filename in filenames:
                file_path = os.path.join(folder_name, filename)
                zip_obj.write(file_path, os.path.basename(file_path))
                zipped_files_num += 1

    print('Zipped {} files to '.format(zipped_files_num), images_previews_zip_name)
zip_image_previews()
Zipped 2600 files to images_previews.zip
# Generate a few images from fresh random noise and show each noise vector
# next to the image produced from it.
test_examples_num = 10
# NOTE(review): the noise length is hard-coded as 100 here (and as 10x10 in
# the reshape below) instead of using `noise_dim`; keep them in sync.
noise_images = tf.random.normal([test_examples_num, 100])
generated_images = generator_model(noise_images, training=False)
for example_num in range(test_examples_num):
    plt.figure(figsize=(3, 3))
    # Left: the 100-element noise vector laid out as a 10x10 image.
    plt.subplot(1, 2, 1)
    plt.imshow(np.reshape(noise_images[example_num], (10, 10)), cmap=plt.cm.binary)
    # Right: first channel of the generated image.
    plt.subplot(1, 2, 2)
    plt.imshow(generated_images[example_num, :, :, 0], cmap=plt.cm.binary)
# Display a single image using the epoch number
def display_image(epoch_no):
    """Open the preview image stored for epoch `epoch_no`.

    Args:
        epoch_no: Epoch number; it is zero-padded to four digits in the
            file name.

    Returns:
        The image opened with `PIL.Image.open`.
    """
    image_path = '{}image_at_epoch_{:04d}.png'.format(IMAGES_PREVIEW_PATH, epoch_no)
    return PIL.Image.open(image_path)
display_image(EPOCHS)
# Assemble the saved per-epoch previews into an animated GIF.
anim_file = 'clothes_generation_dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
    # Sorting by name orders the frames by epoch (epoch numbers are zero-padded).
    filenames = glob.glob(IMAGES_PREVIEW_PATH + 'image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        # Keep a frame only when round(2 * sqrt(i)) exceeds the rounded value
        # of the last kept frame, so later images are sampled more sparsely.
        frame = 2*(i**0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the final image once more after the loop.
    # NOTE(review): `filename` is unbound here if no preview files matched.
    image = imageio.imread(filename)
    writer.append_data(image)
# Show the resulting animation in the notebook.
display.Image(filename=anim_file)
<IPython.core.display.Image object>
To use this model on the web we need to convert it into a format that TensorFlow.js can load. To do so we may use tfjs-converter as follows:
tensorflowjs_converter --input_format keras \
./experiments/clothes_generation_dcgan/generator_model.h5 \
./demos/public/models/clothes_generation_dcgan
You can find this experiment in the Demo app and play around with it right in your browser to see how the model performs in real life.