Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p tensorflow
Sebastian Raschka CPython 3.7.3 IPython 7.6.1 tensorflow 1.13.1
Implementation of Generative Adversarial Nets (GAN) where the discriminator and the generator are built from convolutional and deconvolutional (transposed-convolution) layers, respectively. In this example, the GAN generator is trained to generate MNIST images.
Uses
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pickle as pkl
# Show which GPU device TensorFlow sees (the captured output below is '/device:GPU:0')
tf.test.gpu_device_name()
'/device:GPU:0'
### Abbreviations
# dis_*: discriminator network
# gen_*: generator network
########################
### Helper functions
########################
def leaky_relu(x, alpha=0.0001):
    """Leaky ReLU: keep x where it wins the element-wise max, else alpha * x.

    Computed as ``max(alpha * x, x)``, which behaves as a leaky ReLU for
    0 <= alpha <= 1 (negative inputs are scaled by ``alpha``).
    """
    leak = x * alpha
    # Argument order (leak first) is kept on purpose: tf.maximum routes the
    # gradient of ties to its first argument.
    return tf.maximum(leak, x)
########################
### DATASET
########################
# Load the MNIST splits from the local 'MNIST_data' directory (the log output
# further down shows the archives being extracted there). Only the images are
# used for GAN training; the labels are ignored in the training loop.
mnist = input_data.read_data_sets('MNIST_data')
#########################
### SETTINGS
#########################
# Hyperparameters
learning_rate = 0.001  # Adam step size, shared by the discriminator and generator optimizers
training_epochs = 50  # full passes over mnist.train
batch_size = 64  # images (and noise vectors) per minibatch
dropout_rate = 0.5  # dropout probability; also the value fed to the 'dropout' placeholder while training
# Architecture
dis_input_size = 784  # flattened 28x28 MNIST image (reshaped to 28x28x1 inside the discriminator)
gen_input_size = 100  # length of the uniform(-1, 1) noise vector fed to the generator
# Other settings
print_interval = 200  # print minibatch costs every `print_interval` minibatches
#########################
### GRAPH DEFINITION
#########################

g = tf.Graph()
with g.as_default():

    # Placeholders for settings
    # Dropout probability: fed as `dropout_rate` while training and 0.0 when sampling.
    dropout = tf.placeholder(tf.float32, shape=None, name='dropout')
    # True while training: batch norm uses batch statistics and dropout is active.
    # False at inference: batch norm uses its moving averages, dropout is disabled.
    is_training = tf.placeholder(tf.bool, shape=None, name='is_training')

    # Input data
    dis_x = tf.placeholder(tf.float32, shape=[None, dis_input_size],
                           name='discriminator_inputs')
    gen_x = tf.placeholder(tf.float32, [None, gen_input_size],
                           name='generator_inputs')

    ##################
    # Generator Model
    ##################

    # NOTE: layers followed by batch norm use `use_bias=False`; batch norm's
    # beta offset makes such a bias redundant. (Passing `bias_initializer=None`
    # does NOT remove the bias -- it is still created and randomly initialized.)
    # `tf.layers.dropout` is a no-op unless `training` is truthy, so the
    # `is_training` placeholder is passed explicitly.

    with tf.variable_scope('generator'):

        # 100 => 3136 => 7x7x64
        gen_fc = tf.layers.dense(inputs=gen_x, units=3136,
                                 use_bias=False,
                                 activation=None)
        gen_fc = tf.layers.batch_normalization(gen_fc, training=is_training)
        gen_fc = leaky_relu(gen_fc)
        gen_fc = tf.reshape(gen_fc, (-1, 7, 7, 64))

        # 7x7x64 => 14x14x32
        deconv1 = tf.layers.conv2d_transpose(gen_fc, filters=32,
                                             kernel_size=(3, 3), strides=(2, 2),
                                             padding='same',
                                             use_bias=False,
                                             activation=None)
        deconv1 = tf.layers.batch_normalization(deconv1, training=is_training)
        deconv1 = leaky_relu(deconv1)
        deconv1 = tf.layers.dropout(deconv1, rate=dropout, training=is_training)

        # 14x14x32 => 28x28x16
        deconv2 = tf.layers.conv2d_transpose(deconv1, filters=16,
                                             kernel_size=(3, 3), strides=(2, 2),
                                             padding='same',
                                             use_bias=False,
                                             activation=None)
        deconv2 = tf.layers.batch_normalization(deconv2, training=is_training)
        deconv2 = leaky_relu(deconv2)
        deconv2 = tf.layers.dropout(deconv2, rate=dropout, training=is_training)

        # 28x28x16 => 28x28x8
        deconv3 = tf.layers.conv2d_transpose(deconv2, filters=8,
                                             kernel_size=(3, 3), strides=(1, 1),
                                             padding='same',
                                             use_bias=False,
                                             activation=None)
        deconv3 = tf.layers.batch_normalization(deconv3, training=is_training)
        deconv3 = leaky_relu(deconv3)
        deconv3 = tf.layers.dropout(deconv3, rate=dropout, training=is_training)

        # 28x28x8 => 28x28x1
        # Output layer: no batch norm follows, so it keeps a regular
        # (zero-initialized) bias.
        gen_logits = tf.layers.conv2d_transpose(deconv3, filters=1,
                                                kernel_size=(3, 3), strides=(1, 1),
                                                padding='same',
                                                activation=None)
        # tanh squashes generated pixels into [-1, 1]
        gen_out = tf.tanh(gen_logits, 'generator_outputs')

    ######################
    # Discriminator Model
    ######################

    def build_discriminator_graph(input_x, reuse=None):
        """Build the discriminator on `input_x` and return (logits, probabilities).

        Variables live in the 'discriminator' scope; pass reuse=True on the
        second call so real and fake inputs share the same weights.
        """
        with tf.variable_scope('discriminator', reuse=reuse):

            # 28x28x1 => 14x14x8
            conv_input = tf.reshape(input_x, (-1, 28, 28, 1))
            conv1 = tf.layers.conv2d(conv_input, filters=8, kernel_size=(3, 3),
                                     strides=(2, 2), padding='same',
                                     use_bias=False,
                                     activation=None)
            conv1 = tf.layers.batch_normalization(conv1, training=is_training)
            conv1 = leaky_relu(conv1)
            conv1 = tf.layers.dropout(conv1, rate=dropout, training=is_training)

            # 14x14x8 => 7x7x32
            conv2 = tf.layers.conv2d(conv1, filters=32, kernel_size=(3, 3),
                                     strides=(2, 2), padding='same',
                                     use_bias=False,
                                     activation=None)
            conv2 = tf.layers.batch_normalization(conv2, training=is_training)
            conv2 = leaky_relu(conv2)
            conv2 = tf.layers.dropout(conv2, rate=dropout, training=is_training)

            # fully connected layer
            fc_input = tf.reshape(conv2, (-1, 7*7*32))
            logits = tf.layers.dense(inputs=fc_input, units=1, activation=None)
            out = tf.sigmoid(logits)

        return logits, out

    # Create a discriminator for real data and a discriminator for fake data
    dis_real_logits, dis_real_out = build_discriminator_graph(dis_x, reuse=False)
    dis_fake_logits, dis_fake_out = build_discriminator_graph(gen_out, reuse=True)

    #####################################
    # Generator and Discriminator Losses
    #####################################

    # Two discriminator cost components: loss on real data + loss on fake data
    # Real data has class label 1, fake data has class label 0
    dis_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real_logits,
                                                            labels=tf.ones_like(dis_real_logits))
    dis_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits,
                                                            labels=tf.zeros_like(dis_fake_logits))
    dis_cost = tf.add(tf.reduce_mean(dis_fake_loss),
                      tf.reduce_mean(dis_real_loss),
                      name='discriminator_cost')

    # Generator cost: difference between dis. prediction and label "1" for real images
    gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits,
                                                       labels=tf.ones_like(dis_fake_logits))
    gen_cost = tf.reduce_mean(gen_loss, name='generator_cost')

    #########################################
    # Generator and Discriminator Optimizers
    #########################################

    # Only trainable variables go to the optimizers; batch-norm moving
    # statistics (non-trainable) are updated by the UPDATE_OPS below instead.
    dis_optimizer = tf.train.AdamOptimizer(learning_rate)
    dis_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    dis_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')

    with tf.control_dependencies(dis_update_ops):  # required to upd. batch_norm params
        dis_train = dis_optimizer.minimize(dis_cost, var_list=dis_train_vars,
                                           name='train_discriminator')

    gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')

    with tf.control_dependencies(gen_update_ops):  # required to upd. batch_norm params
        gen_train = gen_optimizer.minimize(gen_cost, var_list=gen_train_vars,
                                           name='train_generator')

    # Saver to save session for reuse (saves all global variables, including
    # the batch-norm moving statistics needed at inference time)
    saver = tf.train.Saver()
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:17: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use alternatives such as official/mnist/dataset.py from tensorflow/models. WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version. Instructions for updating: Please write your own downloading logic. WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.data to implement this functionality. Extracting MNIST_data/train-images-idx3-ubyte.gz WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.data to implement this functionality. Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use alternatives such as official/mnist/dataset.py from tensorflow/models. 
WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:64: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.dense instead. WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:65: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.batch_normalization instead. WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:74: conv2d_transpose (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.conv2d_transpose instead. WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:77: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.dropout instead. WARNING:tensorflow:From <ipython-input-3-c57ae97e26b0>:121: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version. Instructions for updating: Use keras.layers.conv2d instead. WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.cast instead.
##########################
### TRAINING & EVALUATION
##########################

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())

    # Per-epoch mean costs of both networks, used for the cost plot afterwards
    avg_costs = {'discriminator': [], 'generator': []}

    # Number of full minibatches per epoch (constant across epochs)
    n_batches = mnist.train.num_examples // batch_size

    for epoch in range(training_epochs):
        dis_cost_sum = gen_cost_sum = 0.

        for i in range(n_batches):
            batch_x, _ = mnist.train.next_batch(batch_size)  # labels are unused
            # Rescale pixels (presumably in [0, 1]) to [-1, 1], the range of
            # the generator's tanh output
            batch_x = batch_x * 2 - 1
            noise = np.random.uniform(-1, 1, size=(batch_size, gen_input_size))

            # Train: one discriminator step, then one generator step on the
            # same noise batch
            dis_feed = {'discriminator_inputs:0': batch_x,
                        'generator_inputs:0': noise,
                        'dropout:0': dropout_rate,
                        'is_training:0': True}
            _, dc = sess.run(['train_discriminator', 'discriminator_cost:0'],
                             feed_dict=dis_feed)

            gen_feed = {'generator_inputs:0': noise,
                        'dropout:0': dropout_rate,
                        'is_training:0': True}
            _, gc = sess.run(['train_generator', 'generator_cost:0'],
                             feed_dict=gen_feed)

            dis_cost_sum += dc
            gen_cost_sum += gc

            if i % print_interval == 0:
                print("Minibatch: %04d | Dis/Gen Cost: %.3f/%.3f" % (i + 1, dc, gc))

        dis_epoch_avg = dis_cost_sum / n_batches
        gen_epoch_avg = gen_cost_sum / n_batches
        print("Epoch: %04d | Dis/Gen AvgCost: %.3f/%.3f" %
              (epoch + 1, dis_epoch_avg, gen_epoch_avg))
        avg_costs['discriminator'].append(dis_epoch_avg)
        avg_costs['generator'].append(gen_epoch_avg)

    saver.save(sess, save_path='./gan-conv.ckpt')
Minibatch: 0001 | Dis/Gen Cost: 1.630/0.866 Minibatch: 0201 | Dis/Gen Cost: 0.850/1.879 Minibatch: 0401 | Dis/Gen Cost: 0.606/2.467 Minibatch: 0601 | Dis/Gen Cost: 0.695/1.661 Minibatch: 0801 | Dis/Gen Cost: 1.149/1.297 Epoch: 0001 | Dis/Gen AvgCost: 0.820/1.887 Minibatch: 0001 | Dis/Gen Cost: 0.707/1.486 Minibatch: 0201 | Dis/Gen Cost: 0.924/1.438 Minibatch: 0401 | Dis/Gen Cost: 0.751/1.508 Minibatch: 0601 | Dis/Gen Cost: 0.899/1.611 Minibatch: 0801 | Dis/Gen Cost: 0.914/1.535 Epoch: 0002 | Dis/Gen AvgCost: 0.954/1.510 Minibatch: 0001 | Dis/Gen Cost: 0.498/1.955 Minibatch: 0201 | Dis/Gen Cost: 0.757/1.670 Minibatch: 0401 | Dis/Gen Cost: 1.100/1.204 Minibatch: 0601 | Dis/Gen Cost: 0.656/2.054 Minibatch: 0801 | Dis/Gen Cost: 1.036/1.174 Epoch: 0003 | Dis/Gen AvgCost: 0.784/1.720 Minibatch: 0001 | Dis/Gen Cost: 1.576/0.992 Minibatch: 0201 | Dis/Gen Cost: 0.663/2.002 Minibatch: 0401 | Dis/Gen Cost: 0.869/1.773 Minibatch: 0601 | Dis/Gen Cost: 0.675/1.772 Minibatch: 0801 | Dis/Gen Cost: 0.881/1.489 Epoch: 0004 | Dis/Gen AvgCost: 0.898/1.575 Minibatch: 0001 | Dis/Gen Cost: 1.201/1.386 Minibatch: 0201 | Dis/Gen Cost: 1.245/1.606 Minibatch: 0401 | Dis/Gen Cost: 1.281/1.015 Minibatch: 0601 | Dis/Gen Cost: 0.925/1.124 Minibatch: 0801 | Dis/Gen Cost: 1.126/1.634 Epoch: 0005 | Dis/Gen AvgCost: 1.037/1.435 Minibatch: 0001 | Dis/Gen Cost: 0.853/1.626 Minibatch: 0201 | Dis/Gen Cost: 1.204/0.929 Minibatch: 0401 | Dis/Gen Cost: 1.070/1.365 Minibatch: 0601 | Dis/Gen Cost: 1.366/0.927 Minibatch: 0801 | Dis/Gen Cost: 1.253/1.500 Epoch: 0006 | Dis/Gen AvgCost: 1.168/1.186 Minibatch: 0001 | Dis/Gen Cost: 1.590/0.945 Minibatch: 0201 | Dis/Gen Cost: 0.822/1.563 Minibatch: 0401 | Dis/Gen Cost: 0.894/1.410 Minibatch: 0601 | Dis/Gen Cost: 1.292/1.131 Minibatch: 0801 | Dis/Gen Cost: 1.361/1.005 Epoch: 0007 | Dis/Gen AvgCost: 1.248/1.103 Minibatch: 0001 | Dis/Gen Cost: 1.860/0.697 Minibatch: 0201 | Dis/Gen Cost: 1.291/0.986 Minibatch: 0401 | Dis/Gen Cost: 1.097/0.934 Minibatch: 0601 | Dis/Gen 
Cost: 1.316/0.788 Minibatch: 0801 | Dis/Gen Cost: 1.437/0.885 Epoch: 0008 | Dis/Gen AvgCost: 1.298/0.995 Minibatch: 0001 | Dis/Gen Cost: 1.150/1.072 Minibatch: 0201 | Dis/Gen Cost: 1.177/1.148 Minibatch: 0401 | Dis/Gen Cost: 1.351/0.884 Minibatch: 0601 | Dis/Gen Cost: 1.434/0.797 Minibatch: 0801 | Dis/Gen Cost: 1.291/0.929 Epoch: 0009 | Dis/Gen AvgCost: 1.333/0.968 Minibatch: 0001 | Dis/Gen Cost: 1.324/0.764 Minibatch: 0201 | Dis/Gen Cost: 1.255/0.942 Minibatch: 0401 | Dis/Gen Cost: 1.181/1.007 Minibatch: 0601 | Dis/Gen Cost: 1.132/1.134 Minibatch: 0801 | Dis/Gen Cost: 1.170/1.249 Epoch: 0010 | Dis/Gen AvgCost: 1.328/0.922 Minibatch: 0001 | Dis/Gen Cost: 1.539/0.739 Minibatch: 0201 | Dis/Gen Cost: 1.181/1.186 Minibatch: 0401 | Dis/Gen Cost: 1.014/1.331 Minibatch: 0601 | Dis/Gen Cost: 1.380/0.884 Minibatch: 0801 | Dis/Gen Cost: 1.441/0.893 Epoch: 0011 | Dis/Gen AvgCost: 1.306/0.949 Minibatch: 0001 | Dis/Gen Cost: 1.248/0.953 Minibatch: 0201 | Dis/Gen Cost: 1.421/0.751 Minibatch: 0401 | Dis/Gen Cost: 1.323/0.891 Minibatch: 0601 | Dis/Gen Cost: 1.363/0.912 Minibatch: 0801 | Dis/Gen Cost: 1.174/1.112 Epoch: 0012 | Dis/Gen AvgCost: 1.334/0.931 Minibatch: 0001 | Dis/Gen Cost: 1.463/0.792 Minibatch: 0201 | Dis/Gen Cost: 1.296/0.992 Minibatch: 0401 | Dis/Gen Cost: 1.213/1.037 Minibatch: 0601 | Dis/Gen Cost: 1.273/0.899 Minibatch: 0801 | Dis/Gen Cost: 1.282/0.893 Epoch: 0013 | Dis/Gen AvgCost: 1.323/0.910 Minibatch: 0001 | Dis/Gen Cost: 1.192/0.921 Minibatch: 0201 | Dis/Gen Cost: 1.287/0.933 Minibatch: 0401 | Dis/Gen Cost: 1.292/0.898 Minibatch: 0601 | Dis/Gen Cost: 1.164/0.945 Minibatch: 0801 | Dis/Gen Cost: 1.469/0.776 Epoch: 0014 | Dis/Gen AvgCost: 1.312/0.890 Minibatch: 0001 | Dis/Gen Cost: 1.363/0.876 Minibatch: 0201 | Dis/Gen Cost: 1.398/0.759 Minibatch: 0401 | Dis/Gen Cost: 1.099/1.088 Minibatch: 0601 | Dis/Gen Cost: 1.415/0.831 Minibatch: 0801 | Dis/Gen Cost: 1.287/0.813 Epoch: 0015 | Dis/Gen AvgCost: 1.310/0.896 Minibatch: 0001 | Dis/Gen Cost: 1.309/0.910 
Minibatch: 0201 | Dis/Gen Cost: 1.397/0.829 Minibatch: 0401 | Dis/Gen Cost: 1.221/0.949 Minibatch: 0601 | Dis/Gen Cost: 1.284/0.918 Minibatch: 0801 | Dis/Gen Cost: 1.315/0.737 Epoch: 0016 | Dis/Gen AvgCost: 1.306/0.860 Minibatch: 0001 | Dis/Gen Cost: 1.193/0.901 Minibatch: 0201 | Dis/Gen Cost: 1.339/0.908 Minibatch: 0401 | Dis/Gen Cost: 1.119/0.969 Minibatch: 0601 | Dis/Gen Cost: 1.293/0.907 Minibatch: 0801 | Dis/Gen Cost: 1.368/0.882 Epoch: 0017 | Dis/Gen AvgCost: 1.320/0.892 Minibatch: 0001 | Dis/Gen Cost: 1.308/1.014 Minibatch: 0201 | Dis/Gen Cost: 1.194/0.936 Minibatch: 0401 | Dis/Gen Cost: 1.536/0.755 Minibatch: 0601 | Dis/Gen Cost: 1.443/0.810 Minibatch: 0801 | Dis/Gen Cost: 1.288/0.730 Epoch: 0018 | Dis/Gen AvgCost: 1.315/0.867 Minibatch: 0001 | Dis/Gen Cost: 1.259/0.979 Minibatch: 0201 | Dis/Gen Cost: 1.307/0.822 Minibatch: 0401 | Dis/Gen Cost: 1.242/0.845 Minibatch: 0601 | Dis/Gen Cost: 1.422/0.891 Minibatch: 0801 | Dis/Gen Cost: 1.263/0.904 Epoch: 0019 | Dis/Gen AvgCost: 1.306/0.866 Minibatch: 0001 | Dis/Gen Cost: 1.204/0.811 Minibatch: 0201 | Dis/Gen Cost: 1.340/0.810 Minibatch: 0401 | Dis/Gen Cost: 1.278/0.963 Minibatch: 0601 | Dis/Gen Cost: 1.249/0.936 Minibatch: 0801 | Dis/Gen Cost: 1.285/0.945 Epoch: 0020 | Dis/Gen AvgCost: 1.316/0.853 Minibatch: 0001 | Dis/Gen Cost: 1.370/0.772 Minibatch: 0201 | Dis/Gen Cost: 1.478/0.762 Minibatch: 0401 | Dis/Gen Cost: 1.440/0.822 Minibatch: 0601 | Dis/Gen Cost: 1.269/0.809 Minibatch: 0801 | Dis/Gen Cost: 1.260/0.923 Epoch: 0021 | Dis/Gen AvgCost: 1.324/0.837 Minibatch: 0001 | Dis/Gen Cost: 1.401/0.892 Minibatch: 0201 | Dis/Gen Cost: 1.361/0.762 Minibatch: 0401 | Dis/Gen Cost: 1.121/1.012 Minibatch: 0601 | Dis/Gen Cost: 1.366/0.822 Minibatch: 0801 | Dis/Gen Cost: 1.484/0.744 Epoch: 0022 | Dis/Gen AvgCost: 1.314/0.851 Minibatch: 0001 | Dis/Gen Cost: 1.207/0.829 Minibatch: 0201 | Dis/Gen Cost: 1.320/0.786 Minibatch: 0401 | Dis/Gen Cost: 1.327/0.807 Minibatch: 0601 | Dis/Gen Cost: 1.250/0.909 Minibatch: 0801 | Dis/Gen 
Cost: 1.339/0.769 Epoch: 0023 | Dis/Gen AvgCost: 1.323/0.833 Minibatch: 0001 | Dis/Gen Cost: 1.363/0.825 Minibatch: 0201 | Dis/Gen Cost: 1.416/0.738 Minibatch: 0401 | Dis/Gen Cost: 1.290/0.876 Minibatch: 0601 | Dis/Gen Cost: 1.257/0.825 Minibatch: 0801 | Dis/Gen Cost: 1.510/0.633 Epoch: 0024 | Dis/Gen AvgCost: 1.323/0.841 Minibatch: 0001 | Dis/Gen Cost: 1.291/0.694 Minibatch: 0201 | Dis/Gen Cost: 1.400/0.720 Minibatch: 0401 | Dis/Gen Cost: 1.340/0.802 Minibatch: 0601 | Dis/Gen Cost: 1.339/0.784 Minibatch: 0801 | Dis/Gen Cost: 1.211/0.886 Epoch: 0025 | Dis/Gen AvgCost: 1.339/0.811 Minibatch: 0001 | Dis/Gen Cost: 1.395/0.865 Minibatch: 0201 | Dis/Gen Cost: 1.400/0.823 Minibatch: 0401 | Dis/Gen Cost: 1.357/0.811 Minibatch: 0601 | Dis/Gen Cost: 1.404/0.741 Minibatch: 0801 | Dis/Gen Cost: 1.298/0.930 Epoch: 0026 | Dis/Gen AvgCost: 1.340/0.819 Minibatch: 0001 | Dis/Gen Cost: 1.257/0.833 Minibatch: 0201 | Dis/Gen Cost: 1.359/0.772 Minibatch: 0401 | Dis/Gen Cost: 1.453/0.798 Minibatch: 0601 | Dis/Gen Cost: 1.389/0.853 Minibatch: 0801 | Dis/Gen Cost: 1.447/0.754 Epoch: 0027 | Dis/Gen AvgCost: 1.340/0.808 Minibatch: 0001 | Dis/Gen Cost: 1.353/0.764 Minibatch: 0201 | Dis/Gen Cost: 1.353/0.811 Minibatch: 0401 | Dis/Gen Cost: 1.458/0.748 Minibatch: 0601 | Dis/Gen Cost: 1.448/0.753 Minibatch: 0801 | Dis/Gen Cost: 1.475/0.696 Epoch: 0028 | Dis/Gen AvgCost: 1.349/0.792 Minibatch: 0001 | Dis/Gen Cost: 1.271/0.932 Minibatch: 0201 | Dis/Gen Cost: 1.294/0.894 Minibatch: 0401 | Dis/Gen Cost: 1.156/0.866 Minibatch: 0601 | Dis/Gen Cost: 1.292/0.778 Minibatch: 0801 | Dis/Gen Cost: 1.309/0.817 Epoch: 0029 | Dis/Gen AvgCost: 1.347/0.799 Minibatch: 0001 | Dis/Gen Cost: 1.459/0.727 Minibatch: 0201 | Dis/Gen Cost: 1.396/0.753 Minibatch: 0401 | Dis/Gen Cost: 1.367/0.754 Minibatch: 0601 | Dis/Gen Cost: 1.336/0.785 Minibatch: 0801 | Dis/Gen Cost: 1.304/0.756 Epoch: 0030 | Dis/Gen AvgCost: 1.347/0.780 Minibatch: 0001 | Dis/Gen Cost: 1.431/0.726 Minibatch: 0201 | Dis/Gen Cost: 1.348/0.793 
Minibatch: 0401 | Dis/Gen Cost: 1.102/0.823 Minibatch: 0601 | Dis/Gen Cost: 1.276/0.772 Minibatch: 0801 | Dis/Gen Cost: 1.390/0.776 Epoch: 0031 | Dis/Gen AvgCost: 1.337/0.801 Minibatch: 0001 | Dis/Gen Cost: 1.507/0.704 Minibatch: 0201 | Dis/Gen Cost: 1.295/0.873 Minibatch: 0401 | Dis/Gen Cost: 1.312/0.835 Minibatch: 0601 | Dis/Gen Cost: 1.346/0.842 Minibatch: 0801 | Dis/Gen Cost: 1.328/0.721 Epoch: 0032 | Dis/Gen AvgCost: 1.342/0.792 Minibatch: 0001 | Dis/Gen Cost: 1.401/0.717 Minibatch: 0201 | Dis/Gen Cost: 1.436/0.737 Minibatch: 0401 | Dis/Gen Cost: 1.332/0.774 Minibatch: 0601 | Dis/Gen Cost: 1.311/0.804 Minibatch: 0801 | Dis/Gen Cost: 1.391/0.650 Epoch: 0033 | Dis/Gen AvgCost: 1.352/0.783 Minibatch: 0001 | Dis/Gen Cost: 1.317/0.740 Minibatch: 0201 | Dis/Gen Cost: 1.343/0.810 Minibatch: 0401 | Dis/Gen Cost: 1.394/0.717 Minibatch: 0601 | Dis/Gen Cost: 1.455/0.779 Minibatch: 0801 | Dis/Gen Cost: 1.445/0.704 Epoch: 0034 | Dis/Gen AvgCost: 1.348/0.785 Minibatch: 0001 | Dis/Gen Cost: 1.294/0.791 Minibatch: 0201 | Dis/Gen Cost: 1.277/0.886 Minibatch: 0401 | Dis/Gen Cost: 1.349/0.721 Minibatch: 0601 | Dis/Gen Cost: 1.297/0.717 Minibatch: 0801 | Dis/Gen Cost: 1.320/0.777 Epoch: 0035 | Dis/Gen AvgCost: 1.353/0.780 Minibatch: 0001 | Dis/Gen Cost: 1.338/0.756 Minibatch: 0201 | Dis/Gen Cost: 1.273/0.778 Minibatch: 0401 | Dis/Gen Cost: 1.325/0.865 Minibatch: 0601 | Dis/Gen Cost: 1.438/0.717 Minibatch: 0801 | Dis/Gen Cost: 1.328/0.785 Epoch: 0036 | Dis/Gen AvgCost: 1.352/0.770 Minibatch: 0001 | Dis/Gen Cost: 1.375/0.764 Minibatch: 0201 | Dis/Gen Cost: 1.453/0.723 Minibatch: 0401 | Dis/Gen Cost: 1.270/0.807 Minibatch: 0601 | Dis/Gen Cost: 1.392/0.775 Minibatch: 0801 | Dis/Gen Cost: 1.318/0.824 Epoch: 0037 | Dis/Gen AvgCost: 1.353/0.773 Minibatch: 0001 | Dis/Gen Cost: 1.270/0.874 Minibatch: 0201 | Dis/Gen Cost: 1.214/0.833 Minibatch: 0401 | Dis/Gen Cost: 1.456/0.666 Minibatch: 0601 | Dis/Gen Cost: 1.400/0.824 Minibatch: 0801 | Dis/Gen Cost: 1.328/0.736 Epoch: 0038 | Dis/Gen 
AvgCost: 1.354/0.776 Minibatch: 0001 | Dis/Gen Cost: 1.332/0.743 Minibatch: 0201 | Dis/Gen Cost: 1.389/0.710 Minibatch: 0401 | Dis/Gen Cost: 1.375/0.708 Minibatch: 0601 | Dis/Gen Cost: 1.296/0.758 Minibatch: 0801 | Dis/Gen Cost: 1.337/0.783 Epoch: 0039 | Dis/Gen AvgCost: 1.356/0.765 Minibatch: 0001 | Dis/Gen Cost: 1.388/0.706 Minibatch: 0201 | Dis/Gen Cost: 1.371/0.712 Minibatch: 0401 | Dis/Gen Cost: 1.349/0.698 Minibatch: 0601 | Dis/Gen Cost: 1.380/0.723 Minibatch: 0801 | Dis/Gen Cost: 1.371/0.746 Epoch: 0040 | Dis/Gen AvgCost: 1.358/0.759 Minibatch: 0001 | Dis/Gen Cost: 1.349/0.702 Minibatch: 0201 | Dis/Gen Cost: 1.315/0.742 Minibatch: 0401 | Dis/Gen Cost: 1.353/0.760 Minibatch: 0601 | Dis/Gen Cost: 1.335/0.799 Minibatch: 0801 | Dis/Gen Cost: 1.403/0.726 Epoch: 0041 | Dis/Gen AvgCost: 1.362/0.755 Minibatch: 0001 | Dis/Gen Cost: 1.363/0.782 Minibatch: 0201 | Dis/Gen Cost: 1.335/0.742 Minibatch: 0401 | Dis/Gen Cost: 1.344/0.751 Minibatch: 0601 | Dis/Gen Cost: 1.338/0.740 Minibatch: 0801 | Dis/Gen Cost: 1.460/0.735 Epoch: 0042 | Dis/Gen AvgCost: 1.361/0.764 Minibatch: 0001 | Dis/Gen Cost: 1.308/0.767 Minibatch: 0201 | Dis/Gen Cost: 1.367/0.764 Minibatch: 0401 | Dis/Gen Cost: 1.382/0.764 Minibatch: 0601 | Dis/Gen Cost: 1.419/0.625 Minibatch: 0801 | Dis/Gen Cost: 1.393/0.777 Epoch: 0043 | Dis/Gen AvgCost: 1.361/0.753 Minibatch: 0001 | Dis/Gen Cost: 1.413/0.749 Minibatch: 0201 | Dis/Gen Cost: 1.370/0.724 Minibatch: 0401 | Dis/Gen Cost: 1.314/0.756 Minibatch: 0601 | Dis/Gen Cost: 1.321/0.763 Minibatch: 0801 | Dis/Gen Cost: 1.354/0.771 Epoch: 0044 | Dis/Gen AvgCost: 1.364/0.752 Minibatch: 0001 | Dis/Gen Cost: 1.363/0.748 Minibatch: 0201 | Dis/Gen Cost: 1.365/0.727 Minibatch: 0401 | Dis/Gen Cost: 1.439/0.714 Minibatch: 0601 | Dis/Gen Cost: 1.429/0.696 Minibatch: 0801 | Dis/Gen Cost: 1.427/0.699 Epoch: 0045 | Dis/Gen AvgCost: 1.363/0.745 Minibatch: 0001 | Dis/Gen Cost: 1.398/0.713 Minibatch: 0201 | Dis/Gen Cost: 1.408/0.717 Minibatch: 0401 | Dis/Gen Cost: 1.298/0.734 
Minibatch: 0601 | Dis/Gen Cost: 1.345/0.805 Minibatch: 0801 | Dis/Gen Cost: 1.331/0.828 Epoch: 0046 | Dis/Gen AvgCost: 1.366/0.752 Minibatch: 0001 | Dis/Gen Cost: 1.319/0.751 Minibatch: 0201 | Dis/Gen Cost: 1.482/0.713 Minibatch: 0401 | Dis/Gen Cost: 1.341/0.803 Minibatch: 0601 | Dis/Gen Cost: 1.386/0.651 Minibatch: 0801 | Dis/Gen Cost: 1.428/0.701 Epoch: 0047 | Dis/Gen AvgCost: 1.369/0.758 Minibatch: 0001 | Dis/Gen Cost: 1.378/0.747 Minibatch: 0201 | Dis/Gen Cost: 1.355/0.716 Minibatch: 0401 | Dis/Gen Cost: 1.357/0.686 Minibatch: 0601 | Dis/Gen Cost: 1.333/0.767 Minibatch: 0801 | Dis/Gen Cost: 1.380/0.712 Epoch: 0048 | Dis/Gen AvgCost: 1.370/0.735 Minibatch: 0001 | Dis/Gen Cost: 1.409/0.706 Minibatch: 0201 | Dis/Gen Cost: 1.307/0.789 Minibatch: 0401 | Dis/Gen Cost: 1.396/0.731 Minibatch: 0601 | Dis/Gen Cost: 1.375/0.711 Minibatch: 0801 | Dis/Gen Cost: 1.365/0.782 Epoch: 0049 | Dis/Gen AvgCost: 1.371/0.733 Minibatch: 0001 | Dis/Gen Cost: 1.409/0.701 Minibatch: 0201 | Dis/Gen Cost: 1.369/0.728 Minibatch: 0401 | Dis/Gen Cost: 1.315/0.730 Minibatch: 0601 | Dis/Gen Cost: 1.321/0.774 Minibatch: 0801 | Dis/Gen Cost: 1.336/0.735 Epoch: 0050 | Dis/Gen AvgCost: 1.372/0.735
%matplotlib inline
import matplotlib.pyplot as plt
# Plot the per-epoch average cost of each network, discriminator first
for net_name in ('discriminator', 'generator'):
    epoch_costs = avg_costs[net_name]
    plt.plot(range(len(epoch_costs)), epoch_costs, label=net_name)
plt.legend()
plt.show()
####################################
### RELOAD & GENERATE SAMPLE IMAGES
####################################

n_examples = 25

with tf.Session(graph=g) as sess:
    saver.restore(sess, save_path='./gan-conv.ckpt')
    batch_randsample = np.random.uniform(-1, 1, size=(n_examples, gen_input_size))
    # Inference mode: no dropout, batch norm uses its moving averages
    new_examples = sess.run('generator/generator_outputs:0',
                            feed_dict={'generator_inputs:0': batch_randsample,
                                       'dropout:0': 0.0,
                                       'is_training:0': False})

# Side length of the square images: sqrt(784) == 28. (`dis_input_size // 28`
# only yields 28 because 784 happens to equal 28 * 28.)
img_side = int(round(np.sqrt(dis_input_size)))

# Smallest near-square grid holding all samples (5x5 for 25 examples).
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
n_cols = int(np.ceil(np.sqrt(n_examples)))
n_rows = int(np.ceil(n_examples / n_cols))
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(8, 8),
                         sharey=True, sharex=True, squeeze=False)

for image, ax in zip(new_examples, axes.flatten()):
    ax.imshow(image.reshape((img_side, img_side)), cmap='binary')
plt.show()
WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to check for files with this prefix. INFO:tensorflow:Restoring parameters from ./gan-conv.ckpt