from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import skimage.transform as st
# (notebook output) Using TensorFlow backend.
# Folder containing the input (low resolution) dataset
input_path = r'D:\Downloads\selfie2anime\trainB'
# Folder containing the output (high resolution) dataset
output_path = r'D:\Downloads\selfie2anime\trainB'
# Dimensions of the images inside the low resolution dataset (rows, cols, channels).
# NOTE: The sizes must be compatible: output_dimensions / input_dimensions must be a power of 2
input_dimensions = (128,128,3)
# Dimensions of the images inside the high resolution dataset (rows, cols, channels).
output_dimensions = (256,256,3)
# How many times the generator must double the resolution (one UpSampling2D layer each).
# This is log2 of the scale factor. The previous formula, int(scale / 2), only agrees
# with log2 for 2x and 4x scales; bit_length() computes the doubling count correctly
# for any power-of-2 scale (2x -> 1, 4x -> 2, 8x -> 3, ...) and is unchanged here (2x -> 1).
super_sampling_ratio = (output_dimensions[0] // input_dimensions[0]).bit_length() - 1
# Folder where the model as well as generated samples are saved
model_path = r"C:\Users\Vee\Desktop\python\GAN\DLSS\results"
# How many epochs between saving the model
interval = 5
# How many epochs to train the model
epoch = 100
# How many images to train at one time. Ideally a factor of the size of your dataset
batch = 25
# How many convolutional filters for each convolutional layer of the generator and the discriminator
conv_filters = 64
# Size of kernel used in the convolutional layers
kernel = (5,5)
# Boolean flag, set to True if the dataset has PNGs, to remove the alpha layer from images
png = True
class DCGAN():
    """DCGAN-style super-resolution network.

    A generator upscales low resolution images to high resolution; a
    discriminator scores real vs. generated high resolution images.
    Reads the module-level configuration constants (input_path,
    output_path, input_dimensions, output_dimensions, conv_filters,
    kernel, super_sampling_ratio, model_path, png).
    """

    # Initialize parameters, generator, and discriminator models
    def __init__(self):
        # Set dimensions of the output (high resolution) image
        self.img_rows = output_dimensions[0]
        self.img_cols = output_dimensions[1]
        self.channels = output_dimensions[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # Shape of the low resolution input image (used as the "latent" input)
        self.latent_dim = input_dimensions
        # Optimizer shared by both models (lr=0.0002, beta_1=0.5)
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        # NOTE(review): unused local alias of self.generator
        generator = self.generator
        # The generator takes low resolution images as input and generates high resolution images
        z = Input(shape = self.latent_dim)
        img = self.generator(z)
        # For the combined model we will only train the generator
        # (the discriminator itself was already compiled as trainable above)
        self.discriminator.trainable = False
        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)
        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Load data from the configured file paths
    def load_data(self):
        """Load images from output_path/input_path, resize them to the high
        and low resolution shapes, and return them shuffled in unison.

        Returns:
            (X_shuffle, Y_shuffle): 4-D arrays of low resolution inputs and
            high resolution targets, shuffled with matching order.
        """
        # Arrays for image data and image file paths
        data = []
        small = []
        paths = []
        # Get the file paths of all image files in the high resolution folder
        for r, d, f in os.walk(output_path):
            for file in f:
                # NOTE(review): substring test, so names merely containing
                # '.jpg' or 'png' anywhere also match
                if '.jpg' in file or 'png' in file:
                    paths.append(os.path.join(r, file))
        # For each file add the high resolution image to the array
        for path in paths:
            img = Image.open(path)
            # Resize image to the output resolution
            y = np.array(img.resize((self.img_rows,self.img_cols)))
            # Remove alpha layer if images are PNG
            if(png):
                y = y[...,:3]
            data.append(y)
        paths = []
        # Get the file paths of all image files in the low resolution folder
        for r, d, f in os.walk(input_path):
            for file in f:
                if '.jpg' in file or 'png' in file:
                    paths.append(os.path.join(r, file))
        # For each file add the low resolution image to the array
        for path in paths:
            img = Image.open(path)
            # Resize image to the input resolution
            x = np.array(img.resize((self.latent_dim[0],self.latent_dim[1])))
            # Remove alpha layer if images are PNG
            if(png):
                x = x[...,:3]
            small.append(x)
        # Return x_train and y_train reshaped to 4 dimensions
        # (samples, rows, cols, channels)
        y_train = np.array(data)
        y_train = y_train.reshape(len(data),self.img_rows,self.img_cols,self.channels)
        x_train = np.array(small)
        x_train = x_train.reshape(len(small),self.latent_dim[0],self.latent_dim[1],self.latent_dim[2])
        # Free the intermediate lists
        del data
        del small
        del paths
        # Shuffle both arrays in unison so pairs stay aligned
        X_shuffle, Y_shuffle = shuffle(x_train, y_train)
        return X_shuffle, Y_shuffle

    # Define generator model
    def build_generator(self):
        """Build the generator: a fully convolutional network that doubles
        the spatial resolution super_sampling_ratio times via UpSampling2D.

        Returns:
            Model mapping a low resolution image to a high resolution image.
        """
        model = Sequential()
        # 1st convolutional layer / input layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same", input_shape=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample the data as many times as needed to reach output resolution
        for i in range(super_sampling_ratio):
            # Convolutional layer preceding each upsample
            model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
            model.add(LeakyReLU(alpha=0.2))
            # Upsample the data (double the resolution)
            model.add(UpSampling2D())
        # Convolutional layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        # Convolutional layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        # Final convolutional layer (output layer), 3 channels (RGB)
        # NOTE(review): LeakyReLU output is unbounded while the training data
        # is scaled to [0, 1]; a sigmoid output would match the target range
        model.add(Conv2D(3, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.summary()
        noise = Input(shape=self.latent_dim)
        img = model(noise)
        return Model(noise, img)

    # Define discriminator model
    def build_discriminator(self):
        """Build the discriminator: a downsampling CNN ending in a single
        sigmoid unit scoring real (1) vs. generated (0) images.

        Returns:
            Model mapping a high resolution image to a validity score.
        """
        model = Sequential()
        # Input layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, input_shape=self.img_shape,activation = "relu", padding="same"))
        # Downsample the image as many times as the generator upsampled
        for i in range(super_sampling_ratio):
            # Convolutional layer (default 'valid' padding shrinks the map slightly)
            model.add(Conv2D(conv_filters, kernel_size=kernel))
            model.add(LeakyReLU(alpha=0.2))
            # Downsample the data (halve the resolution)
            model.add(MaxPooling2D(pool_size=(2, 2)))
        # Strided convolutional layer (halves the resolution again)
        model.add(Conv2D(conv_filters, kernel_size=kernel, strides = 2))
        model.add(LeakyReLU(alpha=0.2))
        # Strided convolutional layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, strides = 2))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Flatten())
        # Output layer: probability that the image is real
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    # Train the generative adversarial network
    def train(self, epochs, batch_size, save_interval):
        """Train the GAN.

        Args:
            epochs: number of passes over the dataset (clamped to >= 1).
            batch_size: images per training step (clamped to >= 1).
            save_interval: save the model / sample images every N epochs.

        Returns:
            (g_loss_epochs, d_loss_epochs): per-epoch loss histories,
            each shaped (epochs, 1).
        """
        # Prevent the script from crashing on bad user input
        if(epochs <= 0):
            epochs = 1
        if(batch_size <= 0):
            batch_size = 1
        # Load the dataset
        X_train, Y_train = self.load_data()
        # Normalizing data to be between 0 and 1
        X_train = X_train / 255
        Y_train = Y_train / 255
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        # Placeholder arrays for per-epoch loss values
        g_loss_epochs = np.zeros((epochs, 1))
        d_loss_epochs = np.zeros((epochs, 1))
        # Training the GAN
        for epoch in range(1, epochs + 1):
            # Initialize the batch window into the training data
            start = 0
            end = start + batch_size
            # Accumulators for loss values over the epoch.
            # NOTE(review): despite the name, discriminator_loss_fake actually
            # accumulates the discriminator ACCURACY (d_loss[1] is the
            # 'accuracy' metric), which is how it is printed below
            discriminator_loss_real = []
            discriminator_loss_fake = []
            generator_loss = []
            # Iterate through the dataset one batch at a time
            # (a partial final batch is dropped by the integer division)
            for i in range(int(len(X_train)/batch_size)):
                # Get batch of images
                imgs_output = Y_train[start:end]
                imgs_input = X_train[start:end]
                # --- Train discriminator ---
                # Make predictions on the current batch using the generator
                gen_imgs = self.generator.predict(imgs_input)
                # Train the discriminator (real classified as ones and generated as zeros)
                d_loss_real = self.discriminator.train_on_batch(imgs_output, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                # d_loss = [mean loss, mean accuracy] over the real and fake passes
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                # --- Train generator ---
                # The generator wants the discriminator to mistake its images as real
                g_loss = self.combined.train_on_batch(imgs_input, valid)
                # Add loss for the current batch to the sum over the entire epoch
                discriminator_loss_real.append(d_loss[0])
                discriminator_loss_fake.append(d_loss[1])
                generator_loss.append(g_loss)
                # Advance the batch window
                start = start + batch_size
                end = end + batch_size
            # Get average [loss, accuracy, generator loss] over the entire epoch
            loss_data = [np.average(discriminator_loss_real),np.average(discriminator_loss_fake),np.average(generator_loss)]
            # Save generator loss history
            g_loss_epochs[epoch - 1] = loss_data[2]
            # Blend of discriminator loss and error rate (1 - accuracy)
            d_loss_epochs[epoch - 1] = (loss_data[0] + (1 - loss_data[1])) / 2
            # Print average loss over the current epoch
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, loss_data[0], loss_data[1]*100, loss_data[2]))
            # If the epoch is at the interval, save the model and generate image samples
            if epoch % save_interval == 0:
                # Select 8 random indexes
                idx = np.random.randint(0, X_train.shape[0], 8)
                # Get a batch of low resolution training images
                x_points = X_train[idx]
                # Make predictions on the batch of training images
                predicted_imgs = self.generator.predict(x_points)
                # Undo normalization: back to 0-255 values for an RGB image
                predicted_imgs = np.array(predicted_imgs) * 255
                np.clip(predicted_imgs, 0, 255, out=predicted_imgs)
                predicted_imgs = predicted_imgs.astype('uint8')
                x_points = np.array(x_points) * 255
                np.clip(x_points, 0, 255, out=x_points)
                x_points = x_points.astype('uint8')
                interpolated_imgs = []
                # Naively interpolate the low resolution images for comparison
                for x in range(len(x_points)):
                    img = Image.fromarray(x_points[x])
                    interpolated_imgs.append(np.array(img.resize((self.img_rows,self.img_cols))))
                # Plot the predictions next to the interpolated images
                # (save_imgs also saves the generator weights)
                self.save_imgs(epoch, predicted_imgs, interpolated_imgs)
        return g_loss_epochs, d_loss_epochs

    # Save the model and generate prediction samples for a given epoch
    def save_imgs(self, epoch, gen_imgs, interpolated):
        """Save a 4x4 comparison figure (interpolated vs. predicted columns)
        to model_path and save the current generator weights as .h5.

        Args:
            epoch: epoch number used in the figure title and file names.
            gen_imgs: uint8 array of generator predictions (>= 8 images).
            interpolated: list of naively upscaled inputs (>= 8 images).
        """
        # Define number of columns and rows of the figure grid
        r, c = 4, 4
        # Placeholder array for MatPlotLib figure subplots (creation order)
        subplots = []
        # Create figure with title
        fig = plt.figure(figsize= (40, 40))
        fig.suptitle("Epoch: " + str(epoch), fontsize=65)
        # Initialize counters needed to track indexes across multiple arrays
        img_count = 0;
        index_count = 0;
        x_count = 0;
        # Loop through columns and rows of the figure
        for i in range(1, c+1):
            for j in range(1, r+1):
                # Even cells within a row show the predictions
                if(j % 2 == 0):
                    img = gen_imgs[index_count]
                    index_count = index_count + 1
                # Odd cells show the interpolated images
                else:
                    img = interpolated[x_count]
                    x_count = x_count + 1
                # Add image to figure, keep a handle to the subplot
                subplots.append(fig.add_subplot(r, c, img_count + 1))
                plt.imshow(img)
                img_count = img_count + 1
        # Add titles to the columns of the figure (first row of subplots)
        subplots[0].set_title("Interpolated", fontsize=45)
        subplots[1].set_title("Predicted", fontsize=45)
        subplots[2].set_title("Interpolated", fontsize=45)
        subplots[3].set_title("Predicted", fontsize=45)
        # Save figure to a .png image in the specified folder
        fig.savefig(model_path + "\\epoch_%d.png" % epoch)
        plt.close()
        # Save the generator to a .h5 file in the specified folder
        self.generator.save(model_path + "\\generator" + str(epoch) + ".h5")
# Build the GAN (compiles the discriminator, generator, and combined model)
dcgan = DCGAN()
# (notebook output) TensorFlow deprecation warnings (tf.nn.max_pool ->
# tf.nn.max_pool2d; tf.where broadcast-rule change), followed by the Keras
# model.summary() tables: discriminator "sequential_1" (Conv2D/LeakyReLU/
# MaxPooling2D stack down to a 1-unit Dense; 366,081 params) and generator
# "sequential_2" (Conv2D/LeakyReLU/UpSampling2D stack ending in a 3-channel
# 256x256 output; 317,059 params).
# Train the GAN and collect the per-epoch generator/discriminator loss histories
g_loss, d_loss = dcgan.train(epochs=epoch, batch_size=batch, save_interval=interval)
# (notebook output) A tf.global_variables deprecation warning, repeated Keras
# UserWarnings about a trainable-weights discrepancy (expected here: the
# discriminator's trainable flag was flipped after it was compiled), and the
# per-epoch training log for epochs 1-100, e.g.
#   1 [D loss: 0.640689, acc.: 59.37%] [G loss: 0.967596]
#   ...
#   100 [D loss: 0.284887, acc.: 88.88%] [G loss: 4.267109]
# showing discriminator accuracy rising (~59% -> ~89%) while generator loss
# grows (~0.97 -> ~4.27) over training.
# Plot the generator and discriminator loss histories over the training run
plt.plot(g_loss)
plt.plot(d_loss)
plt.title('GAN Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
# Legend order matches the plot order above: generator first, discriminator second
plt.legend(['Generator', 'Discriminator'], loc='upper left')
plt.show()