Training CNNs can take a lot of time, and a lot of data is required for that task. However, much of that time is spent learning the low-level filters that the network uses to extract patterns from images. A natural question arises: can we take a neural network trained on one dataset and adapt it to classify different images without repeating the full training process?
This approach is called transfer learning, because we transfer some knowledge from one neural network model to another. In transfer learning, we typically start with a pre-trained model that has been trained on some large image dataset, such as ImageNet. Such models are already good at extracting features from generic images, and in many cases just building a classifier on top of those extracted features can yield a good result.
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import os
from tfcv import *
In this unit, we will solve the real-life problem of classifying images of cats and dogs. We will use the Kaggle Cats vs. Dogs dataset, which can also be downloaded from Microsoft.
Let's download this dataset and extract it into the data directory (this process may take some time!):
if not os.path.exists('data/kagglecatsanddogs_5340.zip'):
    !wget -P data https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
import zipfile
if not os.path.exists('data/PetImages'):
    with zipfile.ZipFile('data/kagglecatsanddogs_5340.zip', 'r') as zip_ref:
        zip_ref.extractall('data')
Unfortunately, there are some corrupt image files in the dataset, so we need to do a quick check for corrupted files. In order not to clutter this tutorial, we have moved the dataset verification code into a module.
check_image_dir('data/PetImages/Cat/*.jpg')
check_image_dir('data/PetImages/Dog/*.jpg')
Corrupt image or wrong format: data/PetImages/Cat/12235.jpg
Corrupt image or wrong format: data/PetImages/Cat/2663.jpg
Corrupt image or wrong format: data/PetImages/Cat/4929.jpg
...
Corrupt image or wrong format: data/PetImages/Dog/719.jpg
Corrupt image or wrong format: data/PetImages/Dog/50.jpg
(output truncated: one similar line for each corrupt file in the Cat and Dog directories)
In previous examples, we loaded datasets built into Keras. Now we will deal with our own dataset, which we need to load from a directory of images.
In real life, image datasets can be quite large, and one cannot rely on all the data fitting into memory. Thus, datasets are often represented as generators that return data in minibatches suitable for training.
To deal with image classification, Keras includes the special function image_dataset_from_directory, which can load images from subdirectories corresponding to different classes. This function also takes care of resizing images, and it can split the dataset into train and test subsets:
data_dir = 'data/PetImages'
batch_size = 64
ds_train = keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split = 0.2,
    subset = 'training',
    seed = 13,
    image_size = (224,224),
    batch_size = batch_size
)
ds_test = keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split = 0.2,
    subset = 'validation',
    seed = 13,
    image_size = (224,224),
    batch_size = batch_size
)
Found 24769 files belonging to 2 classes.
Using 19816 files for training.
Found 24769 files belonging to 2 classes.
Using 4953 files for validation.
It is important to set the same seed value for both calls, because it affects how images are split between the training and test datasets.
The dataset automatically picks up class names from the directory names, and you can access them if needed by calling:
ds_train.class_names
['Cat', 'Dog']
The datasets we have obtained can be passed directly to the fit function to train the model. They contain both images and the corresponding labels, which can be looped over using the following construction:
for x,y in ds_train:
    print(f"Training batch shape: features={x.shape}, labels={y.shape}")
    x_sample, y_sample = x,y
    break

# np.int was removed in recent NumPy versions; plain int works the same here
display_dataset(x_sample.numpy().astype(int),np.expand_dims(y_sample,1),classes=ds_train.class_names)
Training batch shape: features=(64, 224, 224, 3), labels=(64,)
Note: All images in the dataset are represented as floating-point tensors with values in the range 0-255. Before passing them to the neural network, we need to scale those values into the 0-1 range. When plotting images, we either need to do the same, or convert the values to an integer type (which we do in the code above) in order to show matplotlib that we want to plot the original unscaled image.
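As a quick sketch (the toy batch and the exact calls here are illustrative, not part of this notebook), the two options look like this:

```python
import tensorflow as tf

# Toy batch of "images" in the 0-255 float range, standing in for a batch
# produced by image_dataset_from_directory.
batch = tf.random.uniform((4, 224, 224, 3), minval=0, maxval=255)

# Option 1: scale to the 0-1 range before feeding the network. The same can
# be done inside a model with a keras.layers.Rescaling(1/255) layer.
scaled = batch / 255.0

# Option 2: for plotting, cast to an integer type so matplotlib treats the
# values as 0-255 pixel intensities.
as_int = tf.cast(batch, tf.int32)

print(scaled.dtype, as_int.dtype)
```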
For many image classification tasks, one can find pre-trained neural network models. Many of them are available in the keras.applications namespace, and even more models can be found on the Internet. Let's see how the simplest of them, the VGG-16 model, can be loaded and used:
vgg = keras.applications.VGG16()
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
print(f"Most probable class = {tf.argmax(res,1)}")
keras.applications.vgg16.decode_predictions(res.numpy())
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467904/553467096 [==============================] - 6s 0us/step
Most probable class = [208]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
40960/35363 [==================================] - 0s 0us/step
[[('n02099712', 'Labrador_retriever', 0.5340957), ('n02100236', 'German_short-haired_pointer', 0.0939442), ('n02092339', 'Weimaraner', 0.08160535), ('n02099849', 'Chesapeake_Bay_retriever', 0.057179328), ('n02109047', 'Great_Dane', 0.03733857)]]
There are a couple of important things here:

First, we pass the image through the preprocess_input function, which receives a batch of images and returns their processed form. In the case of VGG-16, images are normalized, and a pre-defined average value for each channel is subtracted, because VGG-16 was originally trained with this pre-processing.

The network returns a tensor of probabilities for each class; to get the number of the most probable class, we apply argmax to this tensor.

The result is the number of an ImageNet class. To make sense of this result, we can also use the decode_predictions function, which returns the top n classes together with their names.

Let's also see the architecture of the VGG-16 network:
vgg.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
Deep neural networks such as VGG-16 and other, more modern architectures require quite a lot of computational power to run, so it makes sense to use GPU acceleration if it is available. Luckily, Keras automatically speeds up computations on the GPU when one is present. We can check whether TensorFlow is able to use the GPU with the following code:
tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
If we want to use VGG-16 to extract features from our images, we need the model without its final classification layers. We can instantiate VGG-16 without the top layers using this code:
vgg = keras.applications.VGG16(include_top=False)
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
print(f"Shape after applying VGG-16: {res[0].shape}")
plt.figure(figsize=(15,3))
plt.imshow(res[0].numpy().reshape(-1,512))
Shape after applying VGG-16: (7, 7, 512)
<matplotlib.image.AxesImage at 0x7fafcc685ac0>
The dimension of the feature tensor is 7x7x512, but in order to visualize it we had to reshape it into 2D form.
Now let's see whether those features can be used to classify images. Let's manually take a portion of the images (50 minibatches, in our case) and pre-compute their feature vectors. We can use the TensorFlow dataset API to do that: the map function takes a dataset and applies a given lambda function to transform it. We use this mechanism to construct two new datasets, ds_features_train and ds_features_test, that contain VGG-extracted features instead of the original images.
num = batch_size*50
ds_features_train = ds_train.take(50).map(lambda x,y : (vgg(x),y))
ds_features_test = ds_test.take(10).map(lambda x,y : (vgg(x),y))
for x,y in ds_features_train:
    print(x.shape,y.shape)
    break
(64, 7, 7, 512) (64,)
We used the construction .take(50) to limit the dataset size in order to speed up our demonstration. You can of course perform this experiment on the full dataset.
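One optional tweak, as an assumption on our side rather than something the notebook does: since the mapped dataset re-runs the feature extractor on every pass, the tf.data cache transformation can store the extracted features after the first epoch. In this sketch, a toy dataset and a small Dense layer stand in for ds_train and the VGG base:

```python
import tensorflow as tf

# Toy dataset: 8 feature vectors with dummy labels, in batches of 4.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((8, 4)), tf.zeros(8))).batch(4)

# Stand-in "feature extractor"; built up front because tf.data map functions
# cannot create variables (the real vgg model is likewise already built).
extractor = tf.keras.layers.Dense(2)
extractor.build((None, 4))

ds_features = (ds.map(lambda x, y: (extractor(x), y))
                 .cache()                     # reuse computed features on later epochs
                 .prefetch(tf.data.AUTOTUNE))

shapes = [tuple(x.shape) for x, _ in ds_features]
print(shapes)  # [(4, 2), (4, 2)]
```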
Now that we have a dataset with extracted features, we can train a simple dense classifier to distinguish between cats and dogs. This network will take a feature tensor of shape (7,7,512) and produce one output that corresponds either to a dog or to a cat. Because this is binary classification, we use the sigmoid activation function and binary_crossentropy loss.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(7,7,512)),
    keras.layers.Dense(1,activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(ds_features_train, validation_data=ds_features_test)
50/50 [==============================] - 1896s 38s/step - loss: 1.4845 - acc: 0.9144 - val_loss: 0.7220 - val_acc: 0.9516
The result is great: we can distinguish between a cat and a dog with almost 95% accuracy! However, we have only tested this approach on a subset of the images, because manual feature extraction takes a lot of time.
We can also avoid manually pre-computing the features by using the original VGG-16 network as a whole during training, adding it to our network as the first layer. The beauty of the Keras architecture is that the VGG-16 model we defined above can also be used as a layer in another neural network! We just need to construct a network with a dense classifier on top of it, and then train the whole network using backpropagation.
model = keras.models.Sequential()
model.add(keras.applications.VGG16(include_top=False,input_shape=(224,224,3)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(1,activation='sigmoid'))
model.layers[0].trainable = False
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
dense (Dense)                (None, 1)                 25089
=================================================================
Total params: 14,739,777
Trainable params: 25,089
Non-trainable params: 14,714,688
_________________________________________________________________
This model looks like an end-to-end classification network, which takes an image and returns a class. However, the tricky part is that we want VGG-16 to act as a feature extractor and not be re-trained; thus, we need to freeze the weights of the convolutional feature extractor. We can access the first layer of the network by calling model.layers[0], and we just need to set its trainable property to False.
Note: Freezing the feature extractor weights is needed because otherwise the untrained classifier layer could destroy the original pre-trained weights of the convolutional extractor.
Notice that while the total number of parameters in our network is around 15 million, we are only training 25k of them; all the parameters of the convolutional filters are pre-trained. That is good, because we are able to fine-tune a smaller number of parameters with a smaller number of examples.
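This trainable/frozen split can also be checked programmatically. Here is a sketch with a tiny two-layer model standing in for our VGG-based classifier (the model itself is illustrative):

```python
import tensorflow as tf
from tensorflow import keras

# Toy model: freeze the first layer, then count parameters on each side.
model = keras.models.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8),                        # 4*8 + 8 = 40 parameters
    keras.layers.Dense(1, activation='sigmoid'),  # 8*1 + 1 = 9 parameters
])
model.layers[0].trainable = False

trainable = sum(int(tf.size(w)) for w in model.trainable_weights)
frozen = sum(int(tf.size(w)) for w in model.non_trainable_weights)
print(trainable, frozen)  # 9 40
```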
We will now train our network and see how good we can get. Expect rather long running time, and do not worry if the execution seems frozen for some time.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(ds_train, validation_data=ds_test)
310/310 [==============================] - 265s 716ms/step - loss: 0.9917 - acc: 0.9512 - val_loss: 0.8156 - val_acc: 0.9671
It looks like we have obtained reasonably accurate cats vs. dogs classifier!
Once we have trained the model, we can save model architecture and trained weights to a file for future use:
model.save('data/cats_dogs.tf')
INFO:tensorflow:Assets written to: data/cats_dogs.tf/assets
We can then load the model from file at any time. You may find it useful in case the next experiment destroys the model - you would not have to re-start from scratch.
model = keras.models.load_model('data/cats_dogs.tf')
In the previous section, we trained the final classifier layer to classify images in our own dataset. However, we did not re-train the feature extractor, so our model relied on the features the model learned on ImageNet data. If your objects visually differ from ordinary ImageNet images, this combination of features might not work as well. Thus, it makes sense to train the convolutional layers as well.
To do that, we can unfreeze the convolutional filter parameters that we have previously frozen.
Note: It is important that you freeze parameters first and perform several epochs of training in order to stabilize weights in the classification layer. If you immediately start training end-to-end network with unfrozen parameters, large errors are likely to destroy the pre-trained weights in the convolutional layers.
Our convolutional VGG-16 model is located inside the first layer, and it consists of many layers itself. We can have a look at its structure:
model.layers[0].summary()
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 0
Non-trainable params: 14,714,688
_________________________________________________________________
We can unfreeze all layers of the convolutional base:
model.layers[0].trainable = True
However, unfreezing all of them at once is not the best idea. We can first unfreeze just a few of the final convolutional layers, because they contain the higher-level patterns that are most relevant for our images. For example, to begin with, we can freeze all layers except the last 4:
for i in range(len(model.layers[0].layers)-4):
    model.layers[0].layers[i].trainable = False
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
dense (Dense)                (None, 1)                 25089
=================================================================
Total params: 14,739,777
Trainable params: 7,104,513
Non-trainable params: 7,635,264
_________________________________________________________________
Observe that the number of trainable parameters increased significantly, but it is still around 50% of all parameters.
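As a sanity check, the trainable/non-trainable split reported by `model.summary()` can be recomputed directly from the model's weights. Below is a small hypothetical helper (not part of the tutorial module), demonstrated on a tiny stand-in model rather than our VGG-based one:

```python
import numpy as np
from tensorflow import keras

# Hypothetical helper (illustration only): count trainable and frozen
# parameters, mirroring the numbers printed by model.summary().
def count_params(model):
    trainable = int(sum(np.prod(w.shape) for w in model.trainable_weights))
    frozen = int(sum(np.prod(w.shape) for w in model.non_trainable_weights))
    return trainable, frozen

# Tiny demo: one Dense layer with 3*4 weights + 4 biases = 16 parameters;
# freezing it moves all of them into the non-trainable count.
demo = keras.models.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(4)])
demo.layers[0].trainable = False
print(count_params(demo))  # (0, 16)
```

Running the same helper on our model should reproduce the trainable/non-trainable totals from the summary above.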
After unfreezing, we can train for a few more epochs (in our example, we will do just one). You can also select a lower learning rate, in order to minimize the impact on the pre-trained weights. However, even with a low learning rate, you can expect the accuracy to drop at the beginning of training, before finally reaching a level slightly higher than in the case of fixed weights.
Note: This training is much slower, because we need to propagate gradients back through many layers of the network!
hist = model.fit(ds_train, validation_data=ds_test)
310/310 [==============================] - 201s 645ms/step - loss: 0.5270 - acc: 0.9776 - val_loss: 1.4132 - val_acc: 0.9653
We are likely to achieve higher training accuracy, because we are effectively using a more powerful network with more trainable parameters, but validation accuracy will not increase as much.
Feel free to unfreeze a few more layers of the network and train more, to see if you are able to achieve higher accuracy!
VGG-16 is one of the simplest computer vision architectures, and Keras provides many more pre-trained networks. Among the most frequently used are the ResNet architectures, developed by Microsoft, and Inception, developed by Google. For example, let's explore the architecture of ResNet-50, the simplest ResNet model (ResNet is a family of models of different depths; you can experiment with ResNet-152 if you want to see what a really deep model looks like):
resnet = keras.applications.ResNet50()
resnet.summary()
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_3[0][0]
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, 112, 112, 64) 0           conv1_bn[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           conv1_relu[0][0]
__________________________________________________________________________________________________
pool1_pool (MaxPooling2D)       (None, 56, 56, 64)   0           pool1_pad[0][0]
__________________________________________________________________________________________________

... (output truncated: the residual blocks conv2_block1 through conv5_block3, each a stack of
Conv2D / BatchNormalization / Activation layers joined by an Add skip connection, are omitted) ...

__________________________________________________________________________________________________
conv5_block3_out (Activation)   (None, 7, 7, 2048)   0           conv5_block3_add[0][0]
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 2048)         0           conv5_block3_out[0][0]
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         2049000     avg_pool[0][0]
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
As you can see, the model contains the same familiar building blocks: convolutional layers, pooling layers, and a final dense classifier. We can use this model in exactly the same manner as we used VGG-16 for transfer learning. Try experimenting with the code above, using different ResNet models as the base model, and see how the accuracy changes.
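As a minimal sketch of this idea, we can freeze the pre-trained ResNet-50 feature extractor and stack a small binary classifier on top, just as we did with VGG-16. (The first run downloads the ImageNet weights, and the layer names here follow the standard Keras Applications API.)

```python
import tensorflow as tf
from tensorflow import keras

# Load ResNet50 pre-trained on ImageNet, without the final 1000-class classifier
base = keras.applications.ResNet50(weights='imagenet', include_top=False,
                                   input_shape=(224, 224, 3))
base.trainable = False  # freeze the pre-trained convolutional weights

# Stack a small binary (cats vs. dogs) classifier on top of the extracted features
model = keras.models.Sequential([
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
```

Because only the final dense layer is trainable, training this model is much faster than training the full network from scratch.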
This network contains yet another type of layer: batch normalization. The idea of batch normalization is to keep the values flowing through the neural network within the right range. Neural networks usually work best when all values are in the range [-1,1] or [0,1], which is why we scale/normalize our input data accordingly. However, during training of a deep network, values can drift significantly outside this range, which makes training problematic. A batch normalization layer computes the mean and standard deviation of all values in the current minibatch, and uses them to normalize the signal before passing it on to the next layer. This significantly improves the stability of deep networks.
Using transfer learning, we were able to quickly put together a classifier for our custom object classification task and achieve high accuracy. However, this example was not completely fair, because the original VGG-16 network was pre-trained on ImageNet, which already contains many images of cats and dogs, so we were largely reusing patterns already present in the network. You can expect lower accuracy on more exotic, domain-specific objects, such as parts on a factory production line or different kinds of tree leaves.
As you can see, the more complex tasks we are solving now require more computational power and cannot easily be solved on a CPU. In the next unit, we will try a more lightweight implementation that trains the same kind of model with lower compute resources, at the cost of only slightly lower accuracy.