import tensorflow as tf
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
from imageio import mimsave
from IPython.display import display as display_fn
from IPython.display import Image, clear_output
from pathlib import Path
# Locations of the two input pictures: the content image and the style image.
content_path, style_path = Path('content_image.jpg'), Path('style_image.jpg')
def load_img(path_to_img, max_dim=512):
    """Load a JPEG image as a batched uint8 tensor, longest side scaled to `max_dim`.

    Args:
      path_to_img: path to the image file (str, string tensor, or os.PathLike
        such as pathlib.Path).
      max_dim: target length in pixels of the image's longest side
        (default 512, the value previously hard-coded).

    Returns:
      uint8 tensor of shape (1, new_height, new_width, 3).
    """
    import os

    # tf.io.read_file does not accept pathlib.Path; convert path-likes but leave
    # plain strings / string tensors untouched (str() on a tensor would mangle it).
    if isinstance(path_to_img, os.PathLike):
        path_to_img = os.fspath(path_to_img)
    image = tf.io.read_file(path_to_img)
    # Force 3 channels so grayscale JPEGs still match InceptionV3's RGB input.
    image = tf.image.decode_jpeg(image, channels=3)
    # Float in [0, 1] so resizing operates on normalized values.
    image = tf.image.convert_image_dtype(image, tf.float32)

    # (height, width) as floats; scale both so the longer side becomes max_dim.
    shape = tf.cast(tf.shape(image)[:-1], tf.float32)
    # tf.reduce_max works in graph mode too, unlike Python's builtin max().
    long_dim = tf.reduce_max(shape)
    scale = tf.cast(max_dim, tf.float32) / long_dim
    new_shape = tf.cast(shape * scale, tf.int32)
    image = tf.image.resize(image, new_shape)

    # Add a leading batch axis and convert back to uint8 pixel values.
    image = image[tf.newaxis, :]
    image = tf.image.convert_image_dtype(image, tf.uint8)
    return image
def load_images(content_path, style_path):
    """Load the content image and the style image as tensors.

    Each path is formatted to a string before being passed to load_img;
    the content image is loaded first, then the style image.

    Returns:
      A (content_image, style_image) tuple.
    """
    content_image, style_image = (
        load_img(f"{path}") for path in (content_path, style_path)
    )
    return content_image, style_image
def imshow(image, title=None):
    """Draw one image with matplotlib and, if a title is given, label it.

    When the image has more than three dimensions, axis 0 (a batch axis)
    is squeezed away before drawing.
    """
    to_draw = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
    plt.imshow(to_draw)
    if not title:
        return
    plt.title(title)
def show_images_with_objects(images, titles=None):
    """Display images side by side in a single row, each with its own title.

    Args:
      images: sequence of images accepted by imshow (a leading batch axis
        is squeezed away for 4-D inputs).
      titles: sequence of titles, one per image. If omitted (None), the
        images are shown without titles.

    Raises:
      ValueError: if titles is given and its length differs from images.
    """
    # Avoid a shared mutable default; no titles means "one empty title per image".
    if titles is None:
        titles = [None] * len(images)
    # Previously a mismatch returned silently and nothing was shown; fail loudly.
    if len(images) != len(titles):
        raise ValueError(
            f'got {len(images)} images but {len(titles)} titles')
    # Nothing to draw: don't create an empty figure.
    if not images:
        return

    plt.figure(figsize=(20, 12))
    for idx, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(1, len(images), idx)
        plt.xticks([])
        plt.yticks([])
        imshow(image, title)
# Load both images, then show them side by side labelled with their file paths.
content_image, style_image = load_images(content_path, style_path)
show_images_with_objects(
    images=[content_image, style_image],
    titles=[
        f'content image: {content_path}',
        f'style image: {style_path}',
    ],
)
We will inspect the layers of the Inception model.
# Clear the session so auto-generated layer names are consistent when this
# cell is re-run. Use the tf.keras backend: the model below is built with
# tf.keras, and the standalone `keras` package (imported as K) is not
# guaranteed to be the same module, in which case K.clear_session() would
# not reset tf.keras' layer naming.
tf.keras.backend.clear_session()

# Download the full InceptionV3 model (including its classifier top) and
# print its layers.
tmp_inception = tf.keras.applications.InceptionV3()
tmp_inception.summary()

# The model is only needed for inspection; drop the reference.
del tmp_inception
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5 96116736/96112376 [==============================] - 0s 0us/step 96124928/96112376 [==============================] - 0s 0us/step Model: "inception_v3" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 299, 299, 3 0 [] )] conv2d (Conv2D) (None, 149, 149, 32 864 ['input_1[0][0]'] ) batch_normalization (BatchNorm (None, 149, 149, 32 96 ['conv2d[0][0]'] alization) ) activation (Activation) (None, 149, 149, 32 0 ['batch_normalization[0][0]'] ) conv2d_1 (Conv2D) (None, 147, 147, 32 9216 ['activation[0][0]'] ) batch_normalization_1 (BatchNo (None, 147, 147, 32 96 ['conv2d_1[0][0]'] rmalization) ) activation_1 (Activation) (None, 147, 147, 32 0 ['batch_normalization_1[0][0]'] ) conv2d_2 (Conv2D) (None, 147, 147, 64 18432 ['activation_1[0][0]'] ) batch_normalization_2 (BatchNo (None, 147, 147, 64 192 ['conv2d_2[0][0]'] rmalization) ) activation_2 (Activation) (None, 147, 147, 64 0 ['batch_normalization_2[0][0]'] ) max_pooling2d (MaxPooling2D) (None, 73, 73, 64) 0 ['activation_2[0][0]'] conv2d_3 (Conv2D) (None, 73, 73, 80) 5120 ['max_pooling2d[0][0]'] batch_normalization_3 (BatchNo (None, 73, 73, 80) 240 ['conv2d_3[0][0]'] rmalization) activation_3 (Activation) (None, 73, 73, 80) 0 ['batch_normalization_3[0][0]'] conv2d_4 (Conv2D) (None, 71, 71, 192) 138240 ['activation_3[0][0]'] batch_normalization_4 (BatchNo (None, 71, 71, 192) 576 ['conv2d_4[0][0]'] rmalization) activation_4 (Activation) (None, 71, 71, 192) 0 ['batch_normalization_4[0][0]'] max_pooling2d_1 (MaxPooling2D) (None, 35, 35, 192) 0 ['activation_4[0][0]'] conv2d_8 (Conv2D) (None, 35, 35, 64) 12288 ['max_pooling2d_1[0][0]'] batch_normalization_8 
(BatchNo (None, 35, 35, 64) 192 ['conv2d_8[0][0]'] rmalization) activation_8 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_8[0][0]'] conv2d_6 (Conv2D) (None, 35, 35, 48) 9216 ['max_pooling2d_1[0][0]'] conv2d_9 (Conv2D) (None, 35, 35, 96) 55296 ['activation_8[0][0]'] batch_normalization_6 (BatchNo (None, 35, 35, 48) 144 ['conv2d_6[0][0]'] rmalization) batch_normalization_9 (BatchNo (None, 35, 35, 96) 288 ['conv2d_9[0][0]'] rmalization) activation_6 (Activation) (None, 35, 35, 48) 0 ['batch_normalization_6[0][0]'] activation_9 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_9[0][0]'] average_pooling2d (AveragePool (None, 35, 35, 192) 0 ['max_pooling2d_1[0][0]'] ing2D) conv2d_5 (Conv2D) (None, 35, 35, 64) 12288 ['max_pooling2d_1[0][0]'] conv2d_7 (Conv2D) (None, 35, 35, 64) 76800 ['activation_6[0][0]'] conv2d_10 (Conv2D) (None, 35, 35, 96) 82944 ['activation_9[0][0]'] conv2d_11 (Conv2D) (None, 35, 35, 32) 6144 ['average_pooling2d[0][0]'] batch_normalization_5 (BatchNo (None, 35, 35, 64) 192 ['conv2d_5[0][0]'] rmalization) batch_normalization_7 (BatchNo (None, 35, 35, 64) 192 ['conv2d_7[0][0]'] rmalization) batch_normalization_10 (BatchN (None, 35, 35, 96) 288 ['conv2d_10[0][0]'] ormalization) batch_normalization_11 (BatchN (None, 35, 35, 32) 96 ['conv2d_11[0][0]'] ormalization) activation_5 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_5[0][0]'] activation_7 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_7[0][0]'] activation_10 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_10[0][0]'] activation_11 (Activation) (None, 35, 35, 32) 0 ['batch_normalization_11[0][0]'] mixed0 (Concatenate) (None, 35, 35, 256) 0 ['activation_5[0][0]', 'activation_7[0][0]', 'activation_10[0][0]', 'activation_11[0][0]'] conv2d_15 (Conv2D) (None, 35, 35, 64) 16384 ['mixed0[0][0]'] batch_normalization_15 (BatchN (None, 35, 35, 64) 192 ['conv2d_15[0][0]'] ormalization) activation_15 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_15[0][0]'] 
conv2d_13 (Conv2D) (None, 35, 35, 48) 12288 ['mixed0[0][0]'] conv2d_16 (Conv2D) (None, 35, 35, 96) 55296 ['activation_15[0][0]'] batch_normalization_13 (BatchN (None, 35, 35, 48) 144 ['conv2d_13[0][0]'] ormalization) batch_normalization_16 (BatchN (None, 35, 35, 96) 288 ['conv2d_16[0][0]'] ormalization) activation_13 (Activation) (None, 35, 35, 48) 0 ['batch_normalization_13[0][0]'] activation_16 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_16[0][0]'] average_pooling2d_1 (AveragePo (None, 35, 35, 256) 0 ['mixed0[0][0]'] oling2D) conv2d_12 (Conv2D) (None, 35, 35, 64) 16384 ['mixed0[0][0]'] conv2d_14 (Conv2D) (None, 35, 35, 64) 76800 ['activation_13[0][0]'] conv2d_17 (Conv2D) (None, 35, 35, 96) 82944 ['activation_16[0][0]'] conv2d_18 (Conv2D) (None, 35, 35, 64) 16384 ['average_pooling2d_1[0][0]'] batch_normalization_12 (BatchN (None, 35, 35, 64) 192 ['conv2d_12[0][0]'] ormalization) batch_normalization_14 (BatchN (None, 35, 35, 64) 192 ['conv2d_14[0][0]'] ormalization) batch_normalization_17 (BatchN (None, 35, 35, 96) 288 ['conv2d_17[0][0]'] ormalization) batch_normalization_18 (BatchN (None, 35, 35, 64) 192 ['conv2d_18[0][0]'] ormalization) activation_12 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_12[0][0]'] activation_14 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_14[0][0]'] activation_17 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_17[0][0]'] activation_18 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_18[0][0]'] mixed1 (Concatenate) (None, 35, 35, 288) 0 ['activation_12[0][0]', 'activation_14[0][0]', 'activation_17[0][0]', 'activation_18[0][0]'] conv2d_22 (Conv2D) (None, 35, 35, 64) 18432 ['mixed1[0][0]'] batch_normalization_22 (BatchN (None, 35, 35, 64) 192 ['conv2d_22[0][0]'] ormalization) activation_22 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_22[0][0]'] conv2d_20 (Conv2D) (None, 35, 35, 48) 13824 ['mixed1[0][0]'] conv2d_23 (Conv2D) (None, 35, 35, 96) 55296 ['activation_22[0][0]'] 
batch_normalization_20 (BatchN (None, 35, 35, 48) 144 ['conv2d_20[0][0]'] ormalization) batch_normalization_23 (BatchN (None, 35, 35, 96) 288 ['conv2d_23[0][0]'] ormalization) activation_20 (Activation) (None, 35, 35, 48) 0 ['batch_normalization_20[0][0]'] activation_23 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_23[0][0]'] average_pooling2d_2 (AveragePo (None, 35, 35, 288) 0 ['mixed1[0][0]'] oling2D) conv2d_19 (Conv2D) (None, 35, 35, 64) 18432 ['mixed1[0][0]'] conv2d_21 (Conv2D) (None, 35, 35, 64) 76800 ['activation_20[0][0]'] conv2d_24 (Conv2D) (None, 35, 35, 96) 82944 ['activation_23[0][0]'] conv2d_25 (Conv2D) (None, 35, 35, 64) 18432 ['average_pooling2d_2[0][0]'] batch_normalization_19 (BatchN (None, 35, 35, 64) 192 ['conv2d_19[0][0]'] ormalization) batch_normalization_21 (BatchN (None, 35, 35, 64) 192 ['conv2d_21[0][0]'] ormalization) batch_normalization_24 (BatchN (None, 35, 35, 96) 288 ['conv2d_24[0][0]'] ormalization) batch_normalization_25 (BatchN (None, 35, 35, 64) 192 ['conv2d_25[0][0]'] ormalization) activation_19 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_19[0][0]'] activation_21 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_21[0][0]'] activation_24 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_24[0][0]'] activation_25 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_25[0][0]'] mixed2 (Concatenate) (None, 35, 35, 288) 0 ['activation_19[0][0]', 'activation_21[0][0]', 'activation_24[0][0]', 'activation_25[0][0]'] conv2d_27 (Conv2D) (None, 35, 35, 64) 18432 ['mixed2[0][0]'] batch_normalization_27 (BatchN (None, 35, 35, 64) 192 ['conv2d_27[0][0]'] ormalization) activation_27 (Activation) (None, 35, 35, 64) 0 ['batch_normalization_27[0][0]'] conv2d_28 (Conv2D) (None, 35, 35, 96) 55296 ['activation_27[0][0]'] batch_normalization_28 (BatchN (None, 35, 35, 96) 288 ['conv2d_28[0][0]'] ormalization) activation_28 (Activation) (None, 35, 35, 96) 0 ['batch_normalization_28[0][0]'] conv2d_26 (Conv2D) (None, 17, 
17, 384) 995328 ['mixed2[0][0]'] conv2d_29 (Conv2D) (None, 17, 17, 96) 82944 ['activation_28[0][0]'] batch_normalization_26 (BatchN (None, 17, 17, 384) 1152 ['conv2d_26[0][0]'] ormalization) batch_normalization_29 (BatchN (None, 17, 17, 96) 288 ['conv2d_29[0][0]'] ormalization) activation_26 (Activation) (None, 17, 17, 384) 0 ['batch_normalization_26[0][0]'] activation_29 (Activation) (None, 17, 17, 96) 0 ['batch_normalization_29[0][0]'] max_pooling2d_2 (MaxPooling2D) (None, 17, 17, 288) 0 ['mixed2[0][0]'] mixed3 (Concatenate) (None, 17, 17, 768) 0 ['activation_26[0][0]', 'activation_29[0][0]', 'max_pooling2d_2[0][0]'] conv2d_34 (Conv2D) (None, 17, 17, 128) 98304 ['mixed3[0][0]'] batch_normalization_34 (BatchN (None, 17, 17, 128) 384 ['conv2d_34[0][0]'] ormalization) activation_34 (Activation) (None, 17, 17, 128) 0 ['batch_normalization_34[0][0]'] conv2d_35 (Conv2D) (None, 17, 17, 128) 114688 ['activation_34[0][0]'] batch_normalization_35 (BatchN (None, 17, 17, 128) 384 ['conv2d_35[0][0]'] ormalization) activation_35 (Activation) (None, 17, 17, 128) 0 ['batch_normalization_35[0][0]'] conv2d_31 (Conv2D) (None, 17, 17, 128) 98304 ['mixed3[0][0]'] conv2d_36 (Conv2D) (None, 17, 17, 128) 114688 ['activation_35[0][0]'] batch_normalization_31 (BatchN (None, 17, 17, 128) 384 ['conv2d_31[0][0]'] ormalization) batch_normalization_36 (BatchN (None, 17, 17, 128) 384 ['conv2d_36[0][0]'] ormalization) activation_31 (Activation) (None, 17, 17, 128) 0 ['batch_normalization_31[0][0]'] activation_36 (Activation) (None, 17, 17, 128) 0 ['batch_normalization_36[0][0]'] conv2d_32 (Conv2D) (None, 17, 17, 128) 114688 ['activation_31[0][0]'] conv2d_37 (Conv2D) (None, 17, 17, 128) 114688 ['activation_36[0][0]'] batch_normalization_32 (BatchN (None, 17, 17, 128) 384 ['conv2d_32[0][0]'] ormalization) batch_normalization_37 (BatchN (None, 17, 17, 128) 384 ['conv2d_37[0][0]'] ormalization) activation_32 (Activation) (None, 17, 17, 128) 0 ['batch_normalization_32[0][0]'] activation_37 
(Activation) (None, 17, 17, 128) 0 ['batch_normalization_37[0][0]'] average_pooling2d_3 (AveragePo (None, 17, 17, 768) 0 ['mixed3[0][0]'] oling2D) conv2d_30 (Conv2D) (None, 17, 17, 192) 147456 ['mixed3[0][0]'] conv2d_33 (Conv2D) (None, 17, 17, 192) 172032 ['activation_32[0][0]'] conv2d_38 (Conv2D) (None, 17, 17, 192) 172032 ['activation_37[0][0]'] conv2d_39 (Conv2D) (None, 17, 17, 192) 147456 ['average_pooling2d_3[0][0]'] batch_normalization_30 (BatchN (None, 17, 17, 192) 576 ['conv2d_30[0][0]'] ormalization) batch_normalization_33 (BatchN (None, 17, 17, 192) 576 ['conv2d_33[0][0]'] ormalization) batch_normalization_38 (BatchN (None, 17, 17, 192) 576 ['conv2d_38[0][0]'] ormalization) batch_normalization_39 (BatchN (None, 17, 17, 192) 576 ['conv2d_39[0][0]'] ormalization) activation_30 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_30[0][0]'] activation_33 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_33[0][0]'] activation_38 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_38[0][0]'] activation_39 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_39[0][0]'] mixed4 (Concatenate) (None, 17, 17, 768) 0 ['activation_30[0][0]', 'activation_33[0][0]', 'activation_38[0][0]', 'activation_39[0][0]'] conv2d_44 (Conv2D) (None, 17, 17, 160) 122880 ['mixed4[0][0]'] batch_normalization_44 (BatchN (None, 17, 17, 160) 480 ['conv2d_44[0][0]'] ormalization) activation_44 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_44[0][0]'] conv2d_45 (Conv2D) (None, 17, 17, 160) 179200 ['activation_44[0][0]'] batch_normalization_45 (BatchN (None, 17, 17, 160) 480 ['conv2d_45[0][0]'] ormalization) activation_45 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_45[0][0]'] conv2d_41 (Conv2D) (None, 17, 17, 160) 122880 ['mixed4[0][0]'] conv2d_46 (Conv2D) (None, 17, 17, 160) 179200 ['activation_45[0][0]'] batch_normalization_41 (BatchN (None, 17, 17, 160) 480 ['conv2d_41[0][0]'] ormalization) batch_normalization_46 (BatchN (None, 17, 17, 160) 480 
['conv2d_46[0][0]'] ormalization) activation_41 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_41[0][0]'] activation_46 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_46[0][0]'] conv2d_42 (Conv2D) (None, 17, 17, 160) 179200 ['activation_41[0][0]'] conv2d_47 (Conv2D) (None, 17, 17, 160) 179200 ['activation_46[0][0]'] batch_normalization_42 (BatchN (None, 17, 17, 160) 480 ['conv2d_42[0][0]'] ormalization) batch_normalization_47 (BatchN (None, 17, 17, 160) 480 ['conv2d_47[0][0]'] ormalization) activation_42 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_42[0][0]'] activation_47 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_47[0][0]'] average_pooling2d_4 (AveragePo (None, 17, 17, 768) 0 ['mixed4[0][0]'] oling2D) conv2d_40 (Conv2D) (None, 17, 17, 192) 147456 ['mixed4[0][0]'] conv2d_43 (Conv2D) (None, 17, 17, 192) 215040 ['activation_42[0][0]'] conv2d_48 (Conv2D) (None, 17, 17, 192) 215040 ['activation_47[0][0]'] conv2d_49 (Conv2D) (None, 17, 17, 192) 147456 ['average_pooling2d_4[0][0]'] batch_normalization_40 (BatchN (None, 17, 17, 192) 576 ['conv2d_40[0][0]'] ormalization) batch_normalization_43 (BatchN (None, 17, 17, 192) 576 ['conv2d_43[0][0]'] ormalization) batch_normalization_48 (BatchN (None, 17, 17, 192) 576 ['conv2d_48[0][0]'] ormalization) batch_normalization_49 (BatchN (None, 17, 17, 192) 576 ['conv2d_49[0][0]'] ormalization) activation_40 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_40[0][0]'] activation_43 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_43[0][0]'] activation_48 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_48[0][0]'] activation_49 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_49[0][0]'] mixed5 (Concatenate) (None, 17, 17, 768) 0 ['activation_40[0][0]', 'activation_43[0][0]', 'activation_48[0][0]', 'activation_49[0][0]'] conv2d_54 (Conv2D) (None, 17, 17, 160) 122880 ['mixed5[0][0]'] batch_normalization_54 (BatchN (None, 17, 17, 160) 480 ['conv2d_54[0][0]'] 
ormalization) activation_54 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_54[0][0]'] conv2d_55 (Conv2D) (None, 17, 17, 160) 179200 ['activation_54[0][0]'] batch_normalization_55 (BatchN (None, 17, 17, 160) 480 ['conv2d_55[0][0]'] ormalization) activation_55 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_55[0][0]'] conv2d_51 (Conv2D) (None, 17, 17, 160) 122880 ['mixed5[0][0]'] conv2d_56 (Conv2D) (None, 17, 17, 160) 179200 ['activation_55[0][0]'] batch_normalization_51 (BatchN (None, 17, 17, 160) 480 ['conv2d_51[0][0]'] ormalization) batch_normalization_56 (BatchN (None, 17, 17, 160) 480 ['conv2d_56[0][0]'] ormalization) activation_51 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_51[0][0]'] activation_56 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_56[0][0]'] conv2d_52 (Conv2D) (None, 17, 17, 160) 179200 ['activation_51[0][0]'] conv2d_57 (Conv2D) (None, 17, 17, 160) 179200 ['activation_56[0][0]'] batch_normalization_52 (BatchN (None, 17, 17, 160) 480 ['conv2d_52[0][0]'] ormalization) batch_normalization_57 (BatchN (None, 17, 17, 160) 480 ['conv2d_57[0][0]'] ormalization) activation_52 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_52[0][0]'] activation_57 (Activation) (None, 17, 17, 160) 0 ['batch_normalization_57[0][0]'] average_pooling2d_5 (AveragePo (None, 17, 17, 768) 0 ['mixed5[0][0]'] oling2D) conv2d_50 (Conv2D) (None, 17, 17, 192) 147456 ['mixed5[0][0]'] conv2d_53 (Conv2D) (None, 17, 17, 192) 215040 ['activation_52[0][0]'] conv2d_58 (Conv2D) (None, 17, 17, 192) 215040 ['activation_57[0][0]'] conv2d_59 (Conv2D) (None, 17, 17, 192) 147456 ['average_pooling2d_5[0][0]'] batch_normalization_50 (BatchN (None, 17, 17, 192) 576 ['conv2d_50[0][0]'] ormalization) batch_normalization_53 (BatchN (None, 17, 17, 192) 576 ['conv2d_53[0][0]'] ormalization) batch_normalization_58 (BatchN (None, 17, 17, 192) 576 ['conv2d_58[0][0]'] ormalization) batch_normalization_59 (BatchN (None, 17, 17, 192) 576 ['conv2d_59[0][0]'] 
ormalization) activation_50 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_50[0][0]'] activation_53 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_53[0][0]'] activation_58 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_58[0][0]'] activation_59 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_59[0][0]'] mixed6 (Concatenate) (None, 17, 17, 768) 0 ['activation_50[0][0]', 'activation_53[0][0]', 'activation_58[0][0]', 'activation_59[0][0]'] conv2d_64 (Conv2D) (None, 17, 17, 192) 147456 ['mixed6[0][0]'] batch_normalization_64 (BatchN (None, 17, 17, 192) 576 ['conv2d_64[0][0]'] ormalization) activation_64 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_64[0][0]'] conv2d_65 (Conv2D) (None, 17, 17, 192) 258048 ['activation_64[0][0]'] batch_normalization_65 (BatchN (None, 17, 17, 192) 576 ['conv2d_65[0][0]'] ormalization) activation_65 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_65[0][0]'] conv2d_61 (Conv2D) (None, 17, 17, 192) 147456 ['mixed6[0][0]'] conv2d_66 (Conv2D) (None, 17, 17, 192) 258048 ['activation_65[0][0]'] batch_normalization_61 (BatchN (None, 17, 17, 192) 576 ['conv2d_61[0][0]'] ormalization) batch_normalization_66 (BatchN (None, 17, 17, 192) 576 ['conv2d_66[0][0]'] ormalization) activation_61 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_61[0][0]'] activation_66 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_66[0][0]'] conv2d_62 (Conv2D) (None, 17, 17, 192) 258048 ['activation_61[0][0]'] conv2d_67 (Conv2D) (None, 17, 17, 192) 258048 ['activation_66[0][0]'] batch_normalization_62 (BatchN (None, 17, 17, 192) 576 ['conv2d_62[0][0]'] ormalization) batch_normalization_67 (BatchN (None, 17, 17, 192) 576 ['conv2d_67[0][0]'] ormalization) activation_62 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_62[0][0]'] activation_67 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_67[0][0]'] average_pooling2d_6 (AveragePo (None, 17, 17, 768) 0 ['mixed6[0][0]'] oling2D) conv2d_60 
(Conv2D) (None, 17, 17, 192) 147456 ['mixed6[0][0]'] conv2d_63 (Conv2D) (None, 17, 17, 192) 258048 ['activation_62[0][0]'] conv2d_68 (Conv2D) (None, 17, 17, 192) 258048 ['activation_67[0][0]'] conv2d_69 (Conv2D) (None, 17, 17, 192) 147456 ['average_pooling2d_6[0][0]'] batch_normalization_60 (BatchN (None, 17, 17, 192) 576 ['conv2d_60[0][0]'] ormalization) batch_normalization_63 (BatchN (None, 17, 17, 192) 576 ['conv2d_63[0][0]'] ormalization) batch_normalization_68 (BatchN (None, 17, 17, 192) 576 ['conv2d_68[0][0]'] ormalization) batch_normalization_69 (BatchN (None, 17, 17, 192) 576 ['conv2d_69[0][0]'] ormalization) activation_60 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_60[0][0]'] activation_63 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_63[0][0]'] activation_68 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_68[0][0]'] activation_69 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_69[0][0]'] mixed7 (Concatenate) (None, 17, 17, 768) 0 ['activation_60[0][0]', 'activation_63[0][0]', 'activation_68[0][0]', 'activation_69[0][0]'] conv2d_72 (Conv2D) (None, 17, 17, 192) 147456 ['mixed7[0][0]'] batch_normalization_72 (BatchN (None, 17, 17, 192) 576 ['conv2d_72[0][0]'] ormalization) activation_72 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_72[0][0]'] conv2d_73 (Conv2D) (None, 17, 17, 192) 258048 ['activation_72[0][0]'] batch_normalization_73 (BatchN (None, 17, 17, 192) 576 ['conv2d_73[0][0]'] ormalization) activation_73 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_73[0][0]'] conv2d_70 (Conv2D) (None, 17, 17, 192) 147456 ['mixed7[0][0]'] conv2d_74 (Conv2D) (None, 17, 17, 192) 258048 ['activation_73[0][0]'] batch_normalization_70 (BatchN (None, 17, 17, 192) 576 ['conv2d_70[0][0]'] ormalization) batch_normalization_74 (BatchN (None, 17, 17, 192) 576 ['conv2d_74[0][0]'] ormalization) activation_70 (Activation) (None, 17, 17, 192) 0 ['batch_normalization_70[0][0]'] activation_74 (Activation) (None, 17, 17, 
192) 0 ['batch_normalization_74[0][0]'] conv2d_71 (Conv2D) (None, 8, 8, 320) 552960 ['activation_70[0][0]'] conv2d_75 (Conv2D) (None, 8, 8, 192) 331776 ['activation_74[0][0]'] batch_normalization_71 (BatchN (None, 8, 8, 320) 960 ['conv2d_71[0][0]'] ormalization) batch_normalization_75 (BatchN (None, 8, 8, 192) 576 ['conv2d_75[0][0]'] ormalization) activation_71 (Activation) (None, 8, 8, 320) 0 ['batch_normalization_71[0][0]'] activation_75 (Activation) (None, 8, 8, 192) 0 ['batch_normalization_75[0][0]'] max_pooling2d_3 (MaxPooling2D) (None, 8, 8, 768) 0 ['mixed7[0][0]'] mixed8 (Concatenate) (None, 8, 8, 1280) 0 ['activation_71[0][0]', 'activation_75[0][0]', 'max_pooling2d_3[0][0]'] conv2d_80 (Conv2D) (None, 8, 8, 448) 573440 ['mixed8[0][0]'] batch_normalization_80 (BatchN (None, 8, 8, 448) 1344 ['conv2d_80[0][0]'] ormalization) activation_80 (Activation) (None, 8, 8, 448) 0 ['batch_normalization_80[0][0]'] conv2d_77 (Conv2D) (None, 8, 8, 384) 491520 ['mixed8[0][0]'] conv2d_81 (Conv2D) (None, 8, 8, 384) 1548288 ['activation_80[0][0]'] batch_normalization_77 (BatchN (None, 8, 8, 384) 1152 ['conv2d_77[0][0]'] ormalization) batch_normalization_81 (BatchN (None, 8, 8, 384) 1152 ['conv2d_81[0][0]'] ormalization) activation_77 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_77[0][0]'] activation_81 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_81[0][0]'] conv2d_78 (Conv2D) (None, 8, 8, 384) 442368 ['activation_77[0][0]'] conv2d_79 (Conv2D) (None, 8, 8, 384) 442368 ['activation_77[0][0]'] conv2d_82 (Conv2D) (None, 8, 8, 384) 442368 ['activation_81[0][0]'] conv2d_83 (Conv2D) (None, 8, 8, 384) 442368 ['activation_81[0][0]'] average_pooling2d_7 (AveragePo (None, 8, 8, 1280) 0 ['mixed8[0][0]'] oling2D) conv2d_76 (Conv2D) (None, 8, 8, 320) 409600 ['mixed8[0][0]'] batch_normalization_78 (BatchN (None, 8, 8, 384) 1152 ['conv2d_78[0][0]'] ormalization) batch_normalization_79 (BatchN (None, 8, 8, 384) 1152 ['conv2d_79[0][0]'] ormalization) batch_normalization_82 
(BatchN (None, 8, 8, 384) 1152 ['conv2d_82[0][0]'] ormalization) batch_normalization_83 (BatchN (None, 8, 8, 384) 1152 ['conv2d_83[0][0]'] ormalization) conv2d_84 (Conv2D) (None, 8, 8, 192) 245760 ['average_pooling2d_7[0][0]'] batch_normalization_76 (BatchN (None, 8, 8, 320) 960 ['conv2d_76[0][0]'] ormalization) activation_78 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_78[0][0]'] activation_79 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_79[0][0]'] activation_82 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_82[0][0]'] activation_83 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_83[0][0]'] batch_normalization_84 (BatchN (None, 8, 8, 192) 576 ['conv2d_84[0][0]'] ormalization) activation_76 (Activation) (None, 8, 8, 320) 0 ['batch_normalization_76[0][0]'] mixed9_0 (Concatenate) (None, 8, 8, 768) 0 ['activation_78[0][0]', 'activation_79[0][0]'] concatenate (Concatenate) (None, 8, 8, 768) 0 ['activation_82[0][0]', 'activation_83[0][0]'] activation_84 (Activation) (None, 8, 8, 192) 0 ['batch_normalization_84[0][0]'] mixed9 (Concatenate) (None, 8, 8, 2048) 0 ['activation_76[0][0]', 'mixed9_0[0][0]', 'concatenate[0][0]', 'activation_84[0][0]'] conv2d_89 (Conv2D) (None, 8, 8, 448) 917504 ['mixed9[0][0]'] batch_normalization_89 (BatchN (None, 8, 8, 448) 1344 ['conv2d_89[0][0]'] ormalization) activation_89 (Activation) (None, 8, 8, 448) 0 ['batch_normalization_89[0][0]'] conv2d_86 (Conv2D) (None, 8, 8, 384) 786432 ['mixed9[0][0]'] conv2d_90 (Conv2D) (None, 8, 8, 384) 1548288 ['activation_89[0][0]'] batch_normalization_86 (BatchN (None, 8, 8, 384) 1152 ['conv2d_86[0][0]'] ormalization) batch_normalization_90 (BatchN (None, 8, 8, 384) 1152 ['conv2d_90[0][0]'] ormalization) activation_86 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_86[0][0]'] activation_90 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_90[0][0]'] conv2d_87 (Conv2D) (None, 8, 8, 384) 442368 ['activation_86[0][0]'] conv2d_88 (Conv2D) (None, 8, 8, 384) 442368 
['activation_86[0][0]'] conv2d_91 (Conv2D) (None, 8, 8, 384) 442368 ['activation_90[0][0]'] conv2d_92 (Conv2D) (None, 8, 8, 384) 442368 ['activation_90[0][0]'] average_pooling2d_8 (AveragePo (None, 8, 8, 2048) 0 ['mixed9[0][0]'] oling2D) conv2d_85 (Conv2D) (None, 8, 8, 320) 655360 ['mixed9[0][0]'] batch_normalization_87 (BatchN (None, 8, 8, 384) 1152 ['conv2d_87[0][0]'] ormalization) batch_normalization_88 (BatchN (None, 8, 8, 384) 1152 ['conv2d_88[0][0]'] ormalization) batch_normalization_91 (BatchN (None, 8, 8, 384) 1152 ['conv2d_91[0][0]'] ormalization) batch_normalization_92 (BatchN (None, 8, 8, 384) 1152 ['conv2d_92[0][0]'] ormalization) conv2d_93 (Conv2D) (None, 8, 8, 192) 393216 ['average_pooling2d_8[0][0]'] batch_normalization_85 (BatchN (None, 8, 8, 320) 960 ['conv2d_85[0][0]'] ormalization) activation_87 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_87[0][0]'] activation_88 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_88[0][0]'] activation_91 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_91[0][0]'] activation_92 (Activation) (None, 8, 8, 384) 0 ['batch_normalization_92[0][0]'] batch_normalization_93 (BatchN (None, 8, 8, 192) 576 ['conv2d_93[0][0]'] ormalization) activation_85 (Activation) (None, 8, 8, 320) 0 ['batch_normalization_85[0][0]'] mixed9_1 (Concatenate) (None, 8, 8, 768) 0 ['activation_87[0][0]', 'activation_88[0][0]'] concatenate_1 (Concatenate) (None, 8, 8, 768) 0 ['activation_91[0][0]', 'activation_92[0][0]'] activation_93 (Activation) (None, 8, 8, 192) 0 ['batch_normalization_93[0][0]'] mixed10 (Concatenate) (None, 8, 8, 2048) 0 ['activation_85[0][0]', 'mixed9_1[0][0]', 'concatenate_1[0][0]', 'activation_93[0][0]'] avg_pool (GlobalAveragePooling (None, 2048) 0 ['mixed10[0][0]'] 2D) predictions (Dense) (None, 1000) 2049000 ['avg_pool[0][0]'] ================================================================================================== Total params: 23,851,784 Trainable params: 23,817,352 Non-trainable params: 
34,432 __________________________________________________________________________________________________
As you can see, it's a very deep network, and compared to VGG-19 it's harder to decide which layers to extract features from.
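Notice that the Conv2D layers are named conv2d, conv2d_1, ..., conv2d_93, for a total of 94 Conv2D layers, so the second Conv2D layer is named conv2d_1. Note that this numbering does not always match the order in which the layers appear in the summary above (for example, conv2d_8 is listed before conv2d_5).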
We choose the following:
# Content features are taken from a single layer deep in the network.
content_layers = ['conv2d_93']

# Style features are taken from the first five convolution layers:
# conv2d, conv2d_1, conv2d_2, conv2d_3 and conv2d_4.
style_layers = ['conv2d'] + [f'conv2d_{i}' for i in range(1, 5)]

# Combined list: style layers first, followed by the content layer.
content_and_style_layers = style_layers + content_layers

# How many content layers and style layers were selected.
NUM_CONTENT_LAYERS, NUM_STYLE_LAYERS = len(content_layers), len(style_layers)
We can now set up our model to output the selected layers.
def inception_model(layer_names):
    """Create an InceptionV3 model that returns the outputs of the named layers.

    Args:
      layer_names: list of strings, the names of the desired content and
        style layers (e.g. 'conv2d', 'conv2d_93').

    Returns:
      A model that takes the regular InceptionV3 input and outputs a list
      with one tensor per entry of layer_names, in the same order.
    """
    # Load InceptionV3 with ImageNet weights but WITHOUT its classification
    # head (the global average pooling 'avg_pool' and the single Dense
    # 'predictions' layer). InceptionV3 has no stack of fully-connected
    # layers like VGG; only the convolutional feature extractor is kept.
    inception = tf.keras.applications.inception_v3.InceptionV3(
        include_top=False, weights='imagenet')

    # Freeze the pretrained weights: the network is only used as a feature
    # extractor.
    inception.trainable = False

    # Look up the output tensor of each requested layer, preserving order.
    output_layers = [inception.get_layer(name).output for name in layer_names]

    # Build a model that maps the InceptionV3 input to those intermediate outputs.
    model = tf.keras.models.Model(inputs=inception.input, outputs=output_layers)
    return model
Create an instance of the content and style model using the function that you just defined.
# Reset tf.keras' session state so the auto-generated layer names match the
# names chosen above (e.g. 'conv2d', 'conv2d_93'). The tf.keras backend is
# used because the model is built with tf.keras; the standalone `keras`
# package (K) is not guaranteed to be the same module.
tf.keras.backend.clear_session()
inception = inception_model(content_and_style_layers)
inception
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 87916544/87910968 [==============================] - 0s 0us/step 87924736/87910968 [==============================] - 0s 0us/step
<keras.engine.functional.Functional at 0x7fe028051e90>
The style loss is the average of the squared differences between the features and targets.
def get_style_loss(features, targets):
    """Return the style loss: the mean of the element-wise squared differences.

    Args:
      features: tensor, e.g. of shape (height, width, channels)
      targets: tensor of the same shape as features

    Returns:
      style loss (scalar)
    """
    squared_error = tf.square(features - targets)
    return tf.reduce_mean(squared_error)
Calculate the sum of the squared error between the features and targets, then multiply by a scaling factor (0.5).
def get_content_loss(features, targets):
    """Return the content loss: half the sum of the element-wise squared differences.

    Args:
      features: tensor, e.g. of shape (height, width, channels)
      targets: tensor of the same shape as features

    Returns:
      content loss (scalar)
    """
    sum_squared_error = tf.reduce_sum(tf.square(features - targets))
    # 0.5 is the fixed scaling factor applied to the summed squared error.
    return 0.5 * sum_squared_error
def gram_matrix(input_tensor):
    """ Gram matrix of a batch of feature maps, normalised by spatial size.

    Args:
      input_tensor: rank-4 tensor (batch, height, width, channels)
    Returns:
      scaled_gram: tensor (batch, channels, channels) holding the channel-by-channel
        inner products summed over all positions, divided by height * width
    """
    # For each image, sum the outer products of channel vectors over every (i, j) position.
    raw_gram = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    dims = tf.shape(input_tensor)
    # dims[1] is the height and dims[2] the width; their product is the location count.
    location_count = tf.cast(dims[1] * dims[2], tf.float32)
    return raw_gram / location_count
Given the style image as input, you'll get its style features from the model you just created with inception_model().
First, preprocess the image with the preprocess_image function.
Then run the preprocessed image through the model to get its outputs.
From those outputs, keep only the style feature layers and discard the content feature layer.
# Gather the symbolic output tensor of every layer in the model so it can be inspected.
tmp_layer_list = list(map(lambda layer: layer.output, inception.layers))
tmp_layer_list
[<KerasTensor: shape=(None, None, None, 3) dtype=float32 (created by layer 'input_1')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'conv2d')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'batch_normalization')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'activation')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'conv2d_1')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'batch_normalization_1')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'activation_1')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_2')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_2')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_2')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'max_pooling2d')>, <KerasTensor: shape=(None, None, None, 80) dtype=float32 (created by layer 'conv2d_3')>, <KerasTensor: shape=(None, None, None, 80) dtype=float32 (created by layer 'batch_normalization_3')>, <KerasTensor: shape=(None, None, None, 80) dtype=float32 (created by layer 'activation_3')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_4')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_4')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_4')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'max_pooling2d_1')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_8')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_8')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_8')>, <KerasTensor: 
shape=(None, None, None, 48) dtype=float32 (created by layer 'conv2d_6')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_9')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'batch_normalization_6')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_9')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'activation_6')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_9')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'average_pooling2d')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_5')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_7')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_10')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'conv2d_11')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_5')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_7')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_10')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'batch_normalization_11')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_5')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_7')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_10')>, <KerasTensor: shape=(None, None, None, 32) dtype=float32 (created by layer 'activation_11')>, <KerasTensor: shape=(None, None, None, 256) dtype=float32 (created by layer 'mixed0')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_15')>, <KerasTensor: shape=(None, 
None, None, 64) dtype=float32 (created by layer 'batch_normalization_15')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_15')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'conv2d_13')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_16')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'batch_normalization_13')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_16')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'activation_13')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_16')>, <KerasTensor: shape=(None, None, None, 256) dtype=float32 (created by layer 'average_pooling2d_1')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_12')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_14')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_17')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_18')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_12')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_14')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_17')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_18')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_12')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_14')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_17')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_18')>, 
<KerasTensor: shape=(None, None, None, 288) dtype=float32 (created by layer 'mixed1')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_22')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_22')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_22')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'conv2d_20')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_23')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'batch_normalization_20')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_23')>, <KerasTensor: shape=(None, None, None, 48) dtype=float32 (created by layer 'activation_20')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_23')>, <KerasTensor: shape=(None, None, None, 288) dtype=float32 (created by layer 'average_pooling2d_2')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_19')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_21')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_24')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_25')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_19')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_21')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_24')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_25')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_19')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 
'activation_21')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_24')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_25')>, <KerasTensor: shape=(None, None, None, 288) dtype=float32 (created by layer 'mixed2')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'conv2d_27')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'batch_normalization_27')>, <KerasTensor: shape=(None, None, None, 64) dtype=float32 (created by layer 'activation_27')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_28')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_28')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_28')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_26')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'conv2d_29')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_26')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'batch_normalization_29')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_26')>, <KerasTensor: shape=(None, None, None, 96) dtype=float32 (created by layer 'activation_29')>, <KerasTensor: shape=(None, None, None, 288) dtype=float32 (created by layer 'max_pooling2d_2')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'mixed3')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'conv2d_34')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_34')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_34')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 
'conv2d_35')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_35')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_35')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'conv2d_31')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'conv2d_36')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_31')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_36')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_31')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_36')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'conv2d_32')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'conv2d_37')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_32')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'batch_normalization_37')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_32')>, <KerasTensor: shape=(None, None, None, 128) dtype=float32 (created by layer 'activation_37')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'average_pooling2d_3')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_30')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_33')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_38')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_39')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_30')>, <KerasTensor: shape=(None, None, None, 192) 
dtype=float32 (created by layer 'batch_normalization_33')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_38')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_39')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_30')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_33')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_38')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_39')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'mixed4')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_44')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_44')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_44')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_45')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_45')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_45')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_41')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_46')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_41')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_46')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_41')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_46')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_42')>, 
<KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_47')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_42')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_47')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_42')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_47')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'average_pooling2d_4')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_40')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_43')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_48')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_49')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_40')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_43')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_48')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_49')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_40')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_43')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_48')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_49')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'mixed5')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_54')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by 
layer 'batch_normalization_54')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_54')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_55')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_55')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_55')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_51')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_56')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_51')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_56')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_51')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_56')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_52')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'conv2d_57')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_52')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'batch_normalization_57')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_52')>, <KerasTensor: shape=(None, None, None, 160) dtype=float32 (created by layer 'activation_57')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'average_pooling2d_5')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_50')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_53')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_58')>, <KerasTensor: shape=(None, None, None, 192) 
dtype=float32 (created by layer 'conv2d_59')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_50')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_53')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_58')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_59')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_50')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_53')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_58')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_59')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'mixed6')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_64')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_64')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_64')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_65')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_65')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_65')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_61')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_66')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_61')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_66')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_61')>, 
<KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_66')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_62')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_67')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_62')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_67')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_62')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_67')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'average_pooling2d_6')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_60')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_63')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_68')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_69')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_60')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_63')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_68')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_69')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_60')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_63')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_68')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_69')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 
(created by layer 'mixed7')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_72')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_72')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_72')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_73')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_73')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_73')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_70')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_74')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_70')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_74')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_70')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_74')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'conv2d_71')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_75')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'batch_normalization_71')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_75')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'activation_71')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_75')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'max_pooling2d_3')>, <KerasTensor: shape=(None, None, None, 1280) dtype=float32 (created by layer 'mixed8')>, <KerasTensor: shape=(None, None, None, 448) 
dtype=float32 (created by layer 'conv2d_80')>, <KerasTensor: shape=(None, None, None, 448) dtype=float32 (created by layer 'batch_normalization_80')>, <KerasTensor: shape=(None, None, None, 448) dtype=float32 (created by layer 'activation_80')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_77')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_81')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_77')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_81')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_77')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_81')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_78')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_79')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_82')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'conv2d_83')>, <KerasTensor: shape=(None, None, None, 1280) dtype=float32 (created by layer 'average_pooling2d_7')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'conv2d_76')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_78')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_79')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_82')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'batch_normalization_83')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_84')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'batch_normalization_76')>, 
<KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_78')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_79')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_82')>, <KerasTensor: shape=(None, None, None, 384) dtype=float32 (created by layer 'activation_83')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'batch_normalization_84')>, <KerasTensor: shape=(None, None, None, 320) dtype=float32 (created by layer 'activation_76')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'mixed9_0')>, <KerasTensor: shape=(None, None, None, 768) dtype=float32 (created by layer 'concatenate')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'activation_84')>, <KerasTensor: shape=(None, None, None, 2048) dtype=float32 (created by layer 'mixed9')>, <KerasTensor: shape=(None, None, None, 2048) dtype=float32 (created by layer 'average_pooling2d_8')>, <KerasTensor: shape=(None, None, None, 192) dtype=float32 (created by layer 'conv2d_93')>]
For each style layer, calculate the Gram matrix. Store these results in a list and return it.
def preprocess_image(image):
    '''Cast the image to float32 and rescale pixel values from [0, 255] to [-1, 1] for Inception.'''
    as_float = tf.cast(image, dtype=tf.float32)
    return as_float / 127.5 - 1.0
def get_style_image_features(image):
    """ Compute the Gram matrix of every style layer for an image.

    Args:
      image: an input image
    Returns:
      gram_style_features: list with one Gram matrix per style layer
    """
    model_outputs = inception(preprocess_image(image))
    # The model lists the style layers first; everything after index
    # NUM_STYLE_LAYERS is the content output.
    style_layer_outputs = model_outputs[:NUM_STYLE_LAYERS]
    return [gram_matrix(layer_output) for layer_output in style_layer_outputs]
def get_content_image_features(image):
    """ Extract the content-layer activations for an image.

    Args:
      image: an input image
    Returns:
      content_outputs: the model outputs from index NUM_STYLE_LAYERS onward
    """
    model_outputs = inception(preprocess_image(image))
    return model_outputs[NUM_STYLE_LAYERS:]
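$L_{total} = \beta L_{style} + \alpha L_{content}$, where $\beta$ is the style weight and $\alpha$ is the content weight. $L_{style}$ and $L_{content}$ are each the per-layer losses summed and then divided by the number of style or content layers, respectively.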
def get_style_content_loss(style_targets, style_outputs, content_targets,
                           content_outputs, style_weight, content_weight):
    """ Weighted sum of the layer-averaged style loss and content loss.

    Args:
      style_targets: style features (Gram matrices) of the style image
      style_outputs: style features (Gram matrices) of the generated image
      content_targets: content features of the content image
      content_outputs: content features of the generated image
      style_weight: multiplier applied to the style loss
      content_weight: multiplier applied to the content loss
    Returns:
      total_loss: style_weight * mean style loss + content_weight * mean content loss
    """
    per_layer_style = [get_style_loss(generated, target)
                       for generated, target in zip(style_outputs, style_targets)]
    per_layer_content = [get_content_loss(generated, target)
                         for generated, target in zip(content_outputs, content_targets)]

    # Weight each summed loss, then divide by the layer count to average it.
    weighted_style = tf.add_n(per_layer_style) * style_weight / NUM_STYLE_LAYERS
    weighted_content = tf.add_n(per_layer_content) * content_weight / NUM_CONTENT_LAYERS

    return weighted_style + weighted_content
We use tf.GradientTape() to get the gradients of the loss with respect to the input image.
def calculate_gradients(image, style_targets, content_targets,
                        style_weight, content_weight):
    """ Differentiate the combined style/content loss with respect to the image.

    Args:
      image: generated image; NOTE(review): the tape only watches it automatically
        if it is a trainable tf.Variable (update_image_with_style calls
        image.assign, so it is presumably one)
      style_targets: style features of the style image
      content_targets: content features of the content image
      style_weight: weight given to the style loss
      content_weight: weight given to the content loss
    Returns:
      gradients: gradients of the loss with respect to the input image
    """
    with tf.GradientTape() as tape:
        # Keyword arguments are evaluated left to right, so the style features
        # are still computed before the content features.
        loss = get_style_content_loss(
            style_targets=style_targets,
            style_outputs=get_style_image_features(image),
            content_targets=content_targets,
            content_outputs=get_content_image_features(image),
            style_weight=style_weight,
            content_weight=content_weight,
        )
    return tape.gradient(loss, image)
def clip_image_values(image, min_value=0.0, max_value=255.0):
    '''Return `image` with every pixel value limited to [min_value, max_value].'''
    clipped = tf.clip_by_value(image, min_value, max_value)
    return clipped
def update_image_with_style(image, style_targets, content_targets, style_weight,
                            content_weight, optimizer):
    """Run one optimization step on the generated image, updating it in place.

    The image is first moved along the loss gradient by `optimizer`, then its
    pixel values are clipped back into [0, 255].

    Args:
      image: generated image (must support .assign, e.g. a tf.Variable)
      style_targets: style features of the style image
      content_targets: content features of the content image
      style_weight: weight given to the style loss
      content_weight: weight given to the content loss
      optimizer: optimizer for updating the input image
    """
    grads = calculate_gradients(image, style_targets, content_targets,
                                style_weight, content_weight)
    optimizer.apply_gradients([(grads, image)])

    # Keep pixels in the valid range after the optimizer step.
    in_range = clip_image_values(image, min_value=0.0, max_value=255.0)
    image.assign(in_range)
def tensor_to_image(tensor):
    '''converts a tensor to an image

    Args:
      tensor: image tensor of rank 3, or of rank > 3 with a leading batch
        dimension of size 1 (that dimension is dropped)
    Returns:
      the image produced by tf.keras.preprocessing.image.array_to_img
    Raises:
      ValueError: if the tensor has rank > 3 and its leading dimension is not 1
    '''
    # Static rank check; no need to build tf.shape(tf.shape(...)) to count dims.
    if len(tensor.shape) > 3:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if tensor.shape[0] != 1:
            raise ValueError(
                'expected a leading batch dimension of size 1, got shape {}'.format(
                    tuple(tensor.shape)))
        tensor = tensor[0]
    return tf.keras.preprocessing.image.array_to_img(tensor)
def fit_style_transfer(style_image, content_image, style_weight=1e-2, content_weight=1e-4,
                       optimizer='adam', epochs=1, steps_per_epoch=1):
    """ Performs neural style transfer.
    Args:
      style_image: image to get style features from
      content_image: image to stylize
      style_weight: weight given to the style loss
      content_weight: weight given to the content loss
      optimizer: optimizer for updating the input image; either an optimizer
        instance or an identifier accepted by tf.keras.optimizers.get
        (e.g. the default 'adam')
      epochs: number of epochs
      steps_per_epoch: steps per epoch
    Returns:
      generated_image: generated image at final epoch, cast to tf.uint8
      images: collection of images for visualization: the content image,
        a snapshot every 10 steps, and a snapshot at the end of each epoch
    """
    # update_image_with_style calls optimizer.apply_gradients, so a string
    # identifier (like the default 'adam') must be turned into an instance.
    if isinstance(optimizer, str):
        optimizer = tf.keras.optimizers.get(optimizer)

    images = []
    step = 0

    # Targets are computed once; they stay fixed during optimization.
    style_targets = get_style_image_features(style_image)
    content_targets = get_content_image_features(content_image)

    # initialize the generated image for updates
    generated_image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

    # collect the image updates starting from the content image
    images.append(content_image)

    for _ in range(epochs):
        for m in range(steps_per_epoch):
            step += 1
            # Update the image with the style using the function that we defined
            update_image_with_style(generated_image, style_targets, content_targets,
                                    style_weight, content_weight, optimizer)
            print(".", end='')
            if (m + 1) % 10 == 0:
                # Store a snapshot of the current value. Appending the
                # tf.Variable itself would alias it, and since it is updated
                # in place every entry would end up showing the final image.
                images.append(tf.identity(generated_image))

        # display the current stylized image
        clear_output(wait=True)
        display_image = tensor_to_image(generated_image)
        display_fn(display_image)

        # append a snapshot to the image collection for visualization later
        images.append(tf.identity(generated_image))
        print("Train step: {}".format(step))

    # convert to uint8 (expected dtype for images with pixels in the range [0,255])
    generated_image = tf.cast(generated_image, dtype=tf.uint8)
    return generated_image, images
With all the functions defined, we can now run the main loop and generate the stylized image.
# Style vs. content weighting; the content weight is tiny so the style term dominates.
style_weight, content_weight = 0.35, 1e-32
# Adam with an exponentially decaying learning rate: it decays smoothly every step
# and is multiplied by 0.80 every 100 steps (i.e. once per epoch of 100 steps below).
adam = tf.optimizers.Adam(
    tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=80.0,
        decay_steps=100,
        decay_rate=0.80,
    )
)
# Run neural style transfer for 10 epochs x 100 steps.
stylized_image, display_images = fit_style_transfer(
    style_image=style_image,
    content_image=content_image,
    style_weight=style_weight,
    content_weight=content_weight,
    optimizer=adam,
    epochs=10,
    steps_per_epoch=100,
)
Train step: 1000
Run again with a higher style weight
# Same run with a heavier style weight; content weight stays negligible.
style_weight, content_weight = 0.8, 1e-32
# Fresh Adam optimizer with the same exponentially decaying learning rate
# (smooth per-step decay, x0.80 every 100 steps, i.e. once per epoch below).
adam = tf.optimizers.Adam(
    tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=80.0,
        decay_steps=100,
        decay_rate=0.80,
    )
)
# Run neural style transfer for 10 epochs x 100 steps.
stylized_image, display_images = fit_style_transfer(
    style_image=style_image,
    content_image=content_image,
    style_weight=style_weight,
    content_weight=content_weight,
    optimizer=adam,
    epochs=10,
    steps_per_epoch=100,
)
Train step: 1000
If you want to try different images, you may need to tune the content and style weights for better results (especially the style weight).