In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

import cv2
import numpy as np
import matplotlib.pyplot as plt
import keras.layers as layers
from keras.models import Model
from keras.preprocessing import image

# 1. magic so that the notebook will reload external python modules
# 2. magic for inline plot
# 3. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
# 4. magic to print version
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format='retina'
%load_ext watermark

%watermark -a 'Ethen' -d -t -v -p cv2,keras,numpy,matplotlib,tensorflow
Using TensorFlow backend.
Ethen 2018-09-04 20:57:58 

CPython 3.6.4
IPython 6.4.0

cv2 3.4.0
keras 2.2.2
numpy 1.14.1
matplotlib 2.2.2
tensorflow 1.7.0

ResNet (Residual Network)

BackGround

In the field of deep learning, often times we would like to add more nodes or more layers to our model to increase the its learning capability (introduce other methods to prevent overfitting as well of course). But the question now arises, is better network as easy as stacking more layers? Or is there a limit in how deep should we go.

As you the imagine, the reason why we're raising this question is because in practice, when we start pushing the "deep-ness" of our network to its extreme, we do observe that making the network too deep can actually be detrimental. The image below shows the training and testing error of a 20-layer and 56-layer convnet. As we can see from the chart, a deeper network actually resulted in a higher training error.

This indicates not all networks can be optimized in the same fashion.

In this notebook, we will be introducing ResNet (Residual Network). When it was unleashed in 2015, this type of network won all image classification, detection, and localization challenges of ImageNet. As depicted in the chart below, this type of network can benefit from increasing the depth of the network and enjoy the accuracy gain, while the counterpart "plain" network degrades in performance when we increases its depth.

Introduction

Before jumping straight into ResNets let's do a quick recap on the concept of a residual.

Residual is the error in a result. Let's say, we were asked to predict the age of a person, just by looking at the person. If his/her actual age is 20, and we predicted 18, we are off by 2, and our residual in this case would be 2. If we had predicted 21, we would have been off by -1. In essence, residual is what we should have added to our prediction to match the actual/expected value.

What is important to understand here is that, if the residual is 0, we should not be performing any action since the prediction already matched the actual output. This idea can be depicted with the following diagram.

In the diagram above, x is our prediction and we want it to be equal to Actual. When is it off by a margin, our residual function residual() will kick in and tries to correct our prediction to match the actual. On the other hand, when x == Actual, residual(x) will be 0. And the Identity function just copies x as is.

Now back to ResNets. This architecture's main highlight is the use of residual blocks. a.k.a skip-connections, shortcuts.

Let us consider $H(x)$ as an underlying mapping to be fit by a few stacked layers (not necessarily the entire net), with $x$ denoting the inputs to the first of these layers. If one hypothesizes that multiple nonlinear layers can asymptotically approximate complicated functions, then it is equivalent to hypothesize that they can asymptotically approximate the residual functions, i.e., $H(x) − x$ (assuming the input and output are of the same dimensions). So rather than expecting our stacked layers to approximate $H(x)$, we explicitly let these layers approximate a residual function $F(x) := H(x) − x$. The original function thus becomes $F(x) + x$.

The paragraph from the original paper can be summarized into the following diagram:

In the diagram above, we have our main path on the left, which consists of multiple nonlinear layers. And on the right, we have our shortcut connections, which skips one or more layers and perform identity mapping (i.e. copying the original input), and their outputs are added to the outputs of the stacked layers.

The idea is that behind this design is that during training, if the residual network learns the identity mapping were optimal weights, all the solver can push the weights of the multiple nonlinear layers to 0. This means no corrections need to be made, our $F(x)$ is essentially set to 0, and the network can use the shortcut path to perform an identity mapping. In real cases, it is unlikely that identity mappings are optimal, but this formulation may help precondition the problem. Meaning it should be easier for the solver to find a more suitable weight with reference to an identity mapping, than to learn a completely new one.

Identity Block

Let's take a look at how we can implement this skip-connection with keras.

The identity block is one of the standard building blocks of ResNets.

The idea is we have a "main path" (the lower path) and the "shortcut path", the identity block here skips over three layers (each layer is composed of a convolutional, a batchnorm layer followed by a relu activation). And right before the end of third layer, i.e. before its relu activation, we add the shortcut path back to the main path.

In [3]:
from keras.initializers import glorot_uniform


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """ 
    An identity block.

    Parameters
    ----------
    input_tensor:

    kernel_size: int
        The kernel size of middle conv layer at main path.

    filters: list[int]
        The filters of 3 conv layer at main path.

    stage: int
        Current stage label, used for generating layer names.

    block: : str
        'a','b'..., current block label, used for generating layer names.

    Returns
    -------
    Output tensor for the block.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    
    # for batch normalization layer, we assume
    # the input data is in channel last format
    bn_axis = 3

    filters1, filters2, filters3 = filters
  
    # main path, note that setting the kernel_initializer seed here is only used
    # for reproducibility, we techniqually don't need it
    x = layers.Conv2D(filters1, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2a')(input_tensor)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size, strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='same', name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # this line is the core component of resnet, the skip connection, i.e.
    # having a shortcut to the main path before the activation, when addition
    # is performed on convolutional layers, the element-wise addition is performed
    # on their feature maps, i.e. channel by channel
    x = layers.add([x, input_tensor])
    x = layers.Activation('relu')(x)
    return x
In [4]:
# generate some fake data to work with
np.random.seed(0)
X = np.random.randn(3, 4, 4, 6)
print('original data shape:', X.shape)
  

stage = 1
block = 'a'
inputs = layers.Input(shape=X.shape[1:])

outputs = identity_block(inputs, kernel_size=2, filters=[2, 4, 6], stage=stage, block=block)

model = Model(inputs=inputs, outputs=outputs)
prediction = model.predict(X)
print('identity block output shape:', prediction.shape)
prediction[1, 1, 0]
original data shape: (3, 4, 4, 6)
identity block output shape: (3, 4, 4, 6)
Out[4]:
array([0.7089733 , 0.        , 1.1173227 , 1.5701073 , 0.8489223 ,
       0.05919211], dtype=float32)

Convolutional Block

The convolutional block is another type of block for ResNet block, and is used when the input and output dimension doesn't match up when we change the channel size. For example, to reduce the activation dimensions's height and width by a factor of 2, we can use a $1 \times 1$ convolution with a stride of 2.

The CONV2D layer on the shortcut path does not use any non-linear activation function. Its main role is to apply a (learned) linear function that reduces the dimension of the input, so that the dimensions match up for the later addition step.

In [5]:
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """ 
    A block that has a conv layer at shortcut.

    Parameters
    ----------
    input_tensor:
    
    kernel_size: int
        The kernel size of middle conv layer at main path.

    filters: list[int]
        The filters of 3 conv layer at main path.

    stage: int
        Current stage label, used for generating layer names.

    block: : str
        'a','b'..., current block label, used for generating layer names.
        
    strides : tuple, default (2, 2)
        Strides for the first conv layer in the block.

    Returns
    -------
    Output tensor for the block.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    
    # for batch normalization layer, we assume
    # the input data is in channel last format,
    # which is the case if we are using the default
    # keras' backend tensorflow
    bn_axis = 3

    filters1, filters2, filters3 = filters
  
    # main path, note that setting the kernel_initializer set here is only used
    # for reproducibility, we techniqually don't need it
    x = layers.Conv2D(filters1, kernel_size=(1, 1), strides=strides,
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2a')(input_tensor)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size, strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='same', name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    
    # we resize the input so its dimension will match the output dimension
    # of the main path
    shortcut = layers.Conv2D(filters3, kernel_size=(1, 1), strides=strides,
                             kernel_initializer=glorot_uniform(seed=0),
                             padding='valid', name=conv_name_base + '1')(input_tensor)
    shortcut = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) 

    # this line is the core component of resnet, the skip connection, i.e.
    # having a shortcut to the main path before the activation
    x = layers.add([x, shortcut])
    x = layers.Activation('relu')(x)
    return x
In [6]:
stage = 1
block = 'a'
inputs = layers.Input(shape=X.shape[1:])

outputs = conv_block(inputs, kernel_size=2, filters=[2, 4, 6], stage=stage, block=block)

model = Model(inputs=inputs, outputs=outputs)
prediction = model.predict(X)
print('identity block output shape:', prediction.shape)
prediction[1, 1, 0]
identity block output shape: (3, 2, 2, 6)
Out[6]:
array([0.        , 1.0165374 , 0.        , 0.        , 0.        ,
       0.48622143], dtype=float32)

ResNet In Action

Now that we have a basic understanding of the definition of a ResNet, we will build one a train it on the MNIST dataset.

In [7]:
from keras.datasets import mnist
from keras.utils import np_utils


(X_train, y_train), (X_test, y_test) = mnist.load_data()
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0] , 'test samples')
X_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
In [8]:
n_classes = 10
img_rows, img_cols = 28, 28

# mnist is grey-scaled image, thus the last dimension, channel size will be 1
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = img_rows, img_cols, 1

X_train = X_train.astype('float32')
X_test  = X_test.astype('float32')

# images takes values between 0 - 255, we can normalize it
# by dividing every number by 255
X_train /= 255
X_test /= 255
print('train shape:', X_train.shape)

# one-hot encode the class (target) vectors
Y_train = np_utils.to_categorical(y_train, n_classes)
Y_test = np_utils.to_categorical(y_test , n_classes)
print('y_train shape:', Y_train.shape)
train shape: (60000, 28, 28, 1)
y_train shape: (60000, 10)
In [9]:
def ResNet(input_shape, n_classes):
    """
    Definition of ResNet
    
    References
    ----------
    https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
    """
    img_input = layers.Input(shape=input_shape)
    
    bn_axis = 3
    
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    
    # the commented out blocks are what's needed to build out the
    # full ResNet50 (a ResNet with 50 layers), we won't be needing
    # the complexity here
    # x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    # x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    img_output = layers.Dense(n_classes, activation='softmax', name='fc' + str(n_classes))(x)

    model = Model(inputs=img_input, outputs=img_output, name='resnet')
    return model


model = ResNet(input_shape, n_classes)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 34, 34, 1)    0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 14, 14, 64)   3200        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 14, 14, 64)   256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 14, 14, 64)   0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 16, 16, 64)   0           activation_7[0][0]               
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 7, 7, 64)     0           pool1_pad[0][0]                  
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 7, 7, 64)     4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 7, 7, 64)     0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 7, 7, 64)     0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_9[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 7, 7, 256)    16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 7, 7, 256)    1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_3 (Add)                     (None, 7, 7, 256)    0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 7, 7, 256)    0           add_3[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 7, 7, 64)     16448       activation_10[0][0]              
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 7, 7, 64)     0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_11[0][0]              
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 7, 7, 64)     0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_12[0][0]              
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_4 (Add)                     (None, 7, 7, 256)    0           bn2b_branch2c[0][0]              
                                                                 activation_10[0][0]              
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 7, 7, 256)    0           add_4[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 7, 7, 64)     16448       activation_13[0][0]              
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 7, 7, 64)     0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_14[0][0]              
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 7, 7, 64)     0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_15[0][0]              
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 7, 7, 256)    0           bn2c_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 7, 7, 256)    0           add_5[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 4, 4, 128)    32896       activation_16[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 4, 4, 128)    0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 4, 4, 128)    0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_18[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 4, 4, 512)    131584      activation_16[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 4, 4, 512)    2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_6 (Add)                     (None, 4, 4, 512)    0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 4, 4, 512)    0           add_6[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 4, 4, 128)    0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 4, 4, 128)    0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, 4, 4, 512)    0           bn3b_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 4, 4, 512)    0           add_7[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_22[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 4, 4, 128)    0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_23[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 4, 4, 128)    0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_24[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_8 (Add)                     (None, 4, 4, 512)    0           bn3c_branch2c[0][0]              
                                                                 activation_22[0][0]              
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 4, 4, 512)    0           add_8[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_25[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 4, 4, 128)    0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_26[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, 4, 4, 128)    0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_27[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 4, 4, 512)    0           bn3d_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, 4, 4, 512)    0           add_9[0][0]                      
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 512)          0           activation_28[0][0]              
__________________________________________________________________________________________________
fc10 (Dense)                    (None, 10)           5130        avg_pool[0][0]                   
==================================================================================================
Total params: 1,458,954
Trainable params: 1,448,842
Non-trainable params: 10,112
__________________________________________________________________________________________________
In [10]:
history = model.fit(X_train, Y_train, epochs=3, batch_size=32)
Epoch 1/3
60000/60000 [==============================] - 168s 3ms/step - loss: 0.1413 - acc: 0.9580
Epoch 2/3
60000/60000 [==============================] - 149s 2ms/step - loss: 0.0635 - acc: 0.9816
Epoch 3/3
60000/60000 [==============================] - 149s 2ms/step - loss: 0.0489 - acc: 0.9859
In [11]:
loss, accuracy = model.evaluate(X_test, Y_test)
print('Loss = ' + str(loss))
print('Test Accuracy = ' + str(accuracy))
10000/10000 [==============================] - 5s 536us/step
Loss = 0.044752538148895835
Test Accuracy = 0.9876

Takeaways:

  • Very deep "plain" networks don't work in practice because they are hard to train due to vanishing gradients.
  • Residual blocks, or so called skip-connections aims to address this vanishing gradient issue by making it easier for a network to learn an identity function.
  • There are two main type of blocks: The identity block and the convolutional block and very deep Residual Networks are built by stacking these blocks together. The latter is used to match the shortcut's dimension and the main path's dimension so the two paths can perform the element-wise addition.

CAM (Class Activation Map)

When working with machine learning models, two common questions that we would love to address is 1. prevent overfitting, 2. being able to provide an explanation of why my model generated the prediction.

In the context of convnets, one way to minimizes the chance of overfitting is the use of Global Average Pooling (GAP). Similar to max pooling layers, GAP layers are used to reduce the spatial dimensions. However, GAP layers perform a more extreme type of dimensionality reduction, where a tensor with dimensions $h \times w \times d$ is reduced to take the size of $1 \times 1 \times d$. i.e. GAP layers reduce each $h \times w$ feature map to a single number by taking the average of all $hw$ values. This operation can be depicted as the following diagram.

The other interesting property about GAP is that its advantages extend beyond acting as a regularization control over our model. With a little bit of tweaking, it allows us to identify exactly which regions of an image are being used for making the prediction towards the predicted class. That is, it will not only tells us what object is contained in the image, but also tells us where the object is in the image. The localization is presented as a heat map (referred to as class activation map from the original paper), where the color-coding scheme identifies regions that are relatively important for the network to perform the object identification task. The following snapshot shows this localization on some sample images:

This approach of highlighting which regions of an image are important to make the classification provides another view of interpreting the inner workings of our convnets.

We will leverage the pre-trained ResNet50 model from Keras to see CAM in action.

In [12]:
from keras.applications.resnet50 import ResNet50, preprocess_input


model = ResNet50(weights='imagenet')
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 112, 112, 64) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 55, 55, 64)   0           activation_29[0][0]              
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 55, 55, 64)   4160        max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, 55, 55, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_30[0][0]              
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 55, 55, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_31[0][0]              
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 55, 55, 256)  16640       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_10 (Add)                    (None, 55, 55, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 55, 55, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_32[0][0]              
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 55, 55, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_33[0][0]              
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 55, 55, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_34[0][0]              
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, 55, 55, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_32[0][0]              
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 55, 55, 256)  0           add_11[0][0]                     
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_35[0][0]              
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 55, 55, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_36[0][0]              
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 55, 55, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_37[0][0]              
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 55, 55, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_35[0][0]              
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 55, 55, 256)  0           add_12[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_38[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_39[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_40[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 28, 28, 512)  131584      activation_38[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512)  2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_13 (Add)                    (None, 28, 28, 512)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 28, 28, 512)  0           add_13[0][0]                     
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_41[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_42[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_43[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_14 (Add)                    (None, 28, 28, 512)  0           bn3b_branch2c[0][0]              
                                                                 activation_41[0][0]              
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 28, 28, 512)  0           add_14[0][0]                     
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_44[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_45[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_46[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 28, 28, 512)  0           bn3c_branch2c[0][0]              
                                                                 activation_44[0][0]              
__________________________________________________________________________________________________
activation_47 (Activation)      (None, 28, 28, 512)  0           add_15[0][0]                     
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_47[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_48[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_49[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_47[0][0]              
__________________________________________________________________________________________________
activation_50 (Activation)      (None, 28, 28, 512)  0           add_16[0][0]                     
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_50[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_51 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_51[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_52 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_52[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 14, 14, 1024) 525312      activation_50[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_17 (Add)                    (None, 14, 14, 1024) 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_53 (Activation)      (None, 14, 14, 1024) 0           add_17[0][0]                     
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_53[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_54 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_54[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_55 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_55[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_18 (Add)                    (None, 14, 14, 1024) 0           bn4b_branch2c[0][0]              
                                                                 activation_53[0][0]              
__________________________________________________________________________________________________
activation_56 (Activation)      (None, 14, 14, 1024) 0           add_18[0][0]                     
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_56[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_57[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_58 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_58[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_19 (Add)                    (None, 14, 14, 1024) 0           bn4c_branch2c[0][0]              
                                                                 activation_56[0][0]              
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 14, 14, 1024) 0           add_19[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_59[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_60 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_60[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_61 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_61[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_20 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
                                                                 activation_59[0][0]              
__________________________________________________________________________________________________
activation_62 (Activation)      (None, 14, 14, 1024) 0           add_20[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_62[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_63 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_63[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_64 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_64[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_21 (Add)                    (None, 14, 14, 1024) 0           bn4e_branch2c[0][0]              
                                                                 activation_62[0][0]              
__________________________________________________________________________________________________
activation_65 (Activation)      (None, 14, 14, 1024) 0           add_21[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_65[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_66 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_66[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_67 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_67[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_22 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_65[0][0]              
__________________________________________________________________________________________________
activation_68 (Activation)      (None, 14, 14, 1024) 0           add_22[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 7, 7, 512)    524800      activation_68[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_69 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_69[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_70 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_70[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 7, 7, 2048)   2099200     activation_68[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_23 (Add)                    (None, 7, 7, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_71 (Activation)      (None, 7, 7, 2048)   0           add_23[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_71[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_72 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_72[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_73 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_73[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_24 (Add)                    (None, 7, 7, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_71[0][0]              
__________________________________________________________________________________________________
activation_74 (Activation)      (None, 7, 7, 2048)   0           add_24[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_74[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_75 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_75[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_76 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_76[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_25 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_74[0][0]              
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 7, 7, 2048)   0           add_25[0][0]                     
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           activation_77[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 2048)         0           avg_pool[0][0]                   
__________________________________________________________________________________________________
fc1000 (Dense)                  (None, 1000)         2049000     flatten_1[0][0]                  
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

The reason why we chose ResNet50 is because the top layer of this network is a GAP layer, immediately followed by a fully connected layer with a softmax activation function that aims to classify our input images' classes, As we will soon see, this is essentially what CAM requires.

Given this pre-trained model, we will load some images to test it out. Note that for CAM to work, our model should be reasonably strong to begin with. We can't use a model that failed to predict the image was a dog and expect it to give us back the region it used to predict this image was a dog.

In [13]:
# change default figure and font size
plt.rcParams['figure.figsize'] = 8, 6
plt.rcParams['font.size'] = 12

# visualize the image that we'll be working with
img_dir = 'images'
img_path = os.path.join(img_dir, 'bmw.png')
img = image.load_img(img_path)
plt.imshow(img)
plt.show()
In [14]:
def read_and_process_img(img_path, target_size=(224, 224)):
    """
    Reads in a single image, resize it to the specified target size
    and performs the same preprocessing on the image as the original
    pre-trained model.
    """
    img = image.load_img(img_path, target_size=target_size)
    img = image.img_to_array(img)

    # keras works with batches of images, since we only have 1 image
    # here, we need to add an additional dimension to turn it into
    # shape [samples, size1, size2, channels]
    img = np.expand_dims(img, axis=0)

    # different pre-trained model preprocess the images differently
    # we also preprocess our images the same way to be consistent
    img = preprocess_input(img)
    return img

    
    
img_dir = 'images'
img_path = os.path.join(img_dir, 'bmw.png')
img = read_and_process_img(img_path)

# print a pixel's channel value
print(img[0, 0, 0])
[119.061  77.221  66.32 ]

In case your curious, the following two code cells shows what's happening underneath the hood when we call resnet50's preprocess_input function.

In [15]:
# the pre-trained resnet was trained on images of size 224 x 224
target_size = 224, 224
img = image.load_img(img_path, target_size=target_size)
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)

# we first transform the image's channel
# from RGB to BGR, if we look at the value
# printed below, see that the first and third
# value are reversed
print(img[0, 0, 0])
img = img[..., ::-1]
print(img[0, 0, 0])
[190. 194. 223.]
[223. 194. 190.]
In [16]:
# after that we subtract each channel
# with the overall mean (103.939 was
# the mean of the B channel across the image
# that the model was trained on) 
mean = [103.939, 116.779, 123.68]

img[..., 0] -= mean[0]
img[..., 1] -= mean[1]
img[..., 2] -= mean[2]

# hopefully, we can see that the result matches
# the preprocess_input method, note that different
# pre-trained model's preprocessing step might differ
# from one to another, it's always a good idea to
# check the source code
img[0, 0, 0]
Out[16]:
array([119.061,  77.221,  66.32 ], dtype=float32)

Given our image, we can pass it directly to our pre-trained model's .predict method to generate the predicted probability of this image falling into ImageNet's 1000 class labels. After that we pass the output probability into a utility function decode_predictions to give us a list of ImageNet class label IDs, “human-readable” labels, and the probability associated with the labels.

In [17]:
from keras.applications.resnet50 import decode_predictions


img = read_and_process_img(img_path)
preds = model.predict(img)
print('prediction shape: ', preds.shape)
print()

preds = decode_predictions(preds, top=5)
print(preds)
print()

# hopefully, we can say that the prediction does in fact makes sense
for i, (imagenetID, label, prob) in enumerate(preds[0]):
    print('{}. {}: {:.2f}%'.format(i + 1, label, prob * 100))
prediction shape:  (1, 1000)

[[('n03100240', 'convertible', 0.954938), ('n04285008', 'sports_car', 0.037788592), ('n02974003', 'car_wheel', 0.004751421), ('n03594945', 'jeep', 0.00059314625), ('n02814533', 'beach_wagon', 0.0005832641)]]

1. convertible: 95.49%
2. sports_car: 3.78%
3. car_wheel: 0.48%
4. jeep: 0.06%
5. beach_wagon: 0.06%

Now that we've confirmed the model was able to classify our sample image pretty correctly, let's move on to the inner workings of CAM.

For our ResNet50, the Activation, AveragePooling2D (this is essentially our GAP layer), Flatten and Dense layers towards the end of the network are what we're interested in. We'll begin with the Activation layer. For ResNet50, this layer contains 2048 activation maps, each with dimensions $7 \times 7$. Let $f_k$ represent the $k$-th activation map, where $k \in \{ 1, \ldots, 2048 \}$. The following AveragePooling2D layer reduces the size of the preceding layer to (1, 1, 2048) by taking the average of each feature map and the next Flatten layer merely flattens the input, without resulting in any change to the information contained in the previous GAP layer.

Our final Dense layer will be of size (2048, 1000), if only we look at the dominated predicted class's weights, $w_k$, then this weight is essentially connecting the $k$-th node in the Flatten layer to the output node corresponding to the predicted image category, which depicts the strength of each activation map for our predicted class.

The process can be summarized into the following diagram:

Or the following formula:

\begin{align} M_c(x, y) &= \sum_k w^c_k \cdot f_k (x, y) \end{align}

Where:

  • $f_k(x, y)$ : Represents the last convolutional layer's activation for activation unit $k$ at the spatial location $x$, $y$. Spatial location here refers to the width and height location of the image.
  • $w^c_k$ : Corresponds to the dense layer's weight for a given class $c$ for the activation unit $k$
  • $M_c(x, y)$ : The output is the class activation map for class $c$, which indicates the importance of the activation at the spatial location $x$ and $y$ for classifying the image to class $c$.
In [18]:
def get_resnet50_cam_info(model):
    # we need the output of the activation layer right before the
    # global average pooling (gap) layer and the last dense/softmax
    # layer that generates the class prediction
    before_gap_layer = model.layers[-4]
    class_pred_layer = model.layers[-1]

    outputs = before_gap_layer.output, class_pred_layer.output
    resnet50_cam_layers = Model(inputs=model.input, outputs=outputs)

    # only access the first element of weights, we won't be needing the bias term here
    class_weights = class_pred_layer.get_weights()[0]
    return resnet50_cam_layers, class_weights


resnet50_cam_layers, class_weights = get_resnet50_cam_info(model)
class_weights.shape
Out[18]:
(2048, 1000)
In [19]:
from scipy.ndimage import zoom


def create_cam(img, resnet50_cam_layers, class_weights, target_size):
    # before_gap_output will be of shape [1, 7, 7, 2048] for resnet50
    before_gap_output, prediction = resnet50_cam_layers.predict(img)

    predicted_class = np.argmax(prediction)
    dominate_class_weight = class_weights[:, predicted_class]

    # the resizing parameters are hard-coded here for clarity purposes only,
    # we resize the shape of the activation so we can perform a dot product with
    # the dominated class weight, then we reshape it back to the target image size
    # so we can overlay the class activation map on top of our image later
    before_gap_output = np.squeeze(before_gap_output).reshape((-1, 2048))
    cam = np.dot(before_gap_output, dominate_class_weight).reshape((7, 7))

    # https://stackoverflow.com/questions/13242382/resampling-a-numpy-array-representing-an-image
    cam = zoom(cam, (32, 32), order=1)
    return cam
    

img = read_and_process_img(img_path)
cam = create_cam(img, resnet50_cam_layers, class_weights, target_size)
cam.shape
Out[19]:
(224, 224)

Some quick side note regarding the zoom function. When performing various transformations on our image, we often need to employ some sort of interpolation in order to obtain a good image quality after the transformation, e.g. after rotation, we found out that a pixel should have the same color as pixel (3.4, 5.6) in the original image. The only problem is that there is no such thing as fractional pixels, thus we need to find a way to impute these values.

For scipy's zoom function, we can specify different interpolation techniques.

  • Nearest is order=0. This is the interpolation that requires the least processing time of all the interpolation algorithms because it only considers one pixel, the closest one to the interpolated point. This has the effect of making each pixel bigger.
  • Bilinear interpolation would be order=1. Bilinear interpolation considers the closest $2 \times 2$ neighborhood of known pixel values surrounding the unknown pixel. It then takes a weighted average of these 4 pixels to arrive at its final interpolated value. This results in much smoother looking images than nearest neighbor.
  • Bicubic is the default (order=3). Bicubic goes one step beyond bilinear by considering the closest $4 \times 4$ neighborhood of known pixels, a total of 16 pixels. Since these are at various distances from the unknown pixel, closer pixels are given a higher weighting in the calculation. Bicubic produces noticeably sharper images than the previous two methods, and is perhaps the ideal combination of processing time (there're higher order available, but as we can image they are often times expensive to compute and can have diminishing returns) and output quality. For this reason it is a standard in many image editing programs (including Adobe Photoshop), printer drivers and in-camera interpolation.

The following code uses OpenCV to load image, convert it from BGR format to RGB as the default channel ordering for OpenCV is BGR, and resizes the image to $224 \times 224$. After that we plot the image and overlay our class activation map over the original image.

OpenCV is often preferred when we're doing a lot of image manipulation and speed is a huge concern (It is often times faster than the PIL library, which is what is used underneath the hood of keras' image preprocessing function). OpenCV is a huge topic by itself and there are various links floating around the web to if you wish to dig deeper. e.g. Blog: Basic Image Manipulations in Python and OpenCV: Resizing (scaling), Rotating, and Cropping

In [20]:
img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
img = cv2.resize(img, target_size)

# we can achieve the same thing here without using opencv,
# it was merely used to ensure we know about the tool's existence
# img = image.load_img(img_path, target_size=target_size)

fig, ax = plt.subplots()
ax.imshow(img, alpha=0.5)
ax.imshow(cam, cmap='jet', alpha=0.5)

plt.show()

We can organize the implementation above into a class and generate class activation maps on multiple images.

In [21]:
from explainer import CAMExplainer

cam_explainer = CAMExplainer(model, target_size)

img_files = ['soccer_ball.jpg', 'jemma.png', 'office.png']
for img_file in img_files:  
    img_path = os.path.join(img_dir, img_file)
    cam_explainer.fit(img_path).plot()