Exercise 12.2 - Solution

Activation maximization

In this task, we use activation maximization to visualize the patterns to which the features of a CNN trained on MNIST are sensitive. This gives us a deeper understanding of how CNNs work.
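The idea can be sketched on a toy problem before turning to the CNN. The following is a minimal NumPy illustration, not the notebook's implementation: the "feature" is assumed to be a single linear filter `w` (a hypothetical stand-in for a feature map), and the input is updated by gradient ascent so that the filter response grows.

```python
import numpy as np

def activation(x, w):
    """Toy 'feature': linear response of a single filter w to the input x."""
    return np.sum(w * x)

def rms_normalize(g, eps=1e-7):
    """Rescale the gradient to unit root-mean-square."""
    return g / (np.sqrt(np.mean(np.square(g))) + eps)

rng = np.random.default_rng(0)
w = rng.normal(size=(5, 5))          # hypothetical filter, not taken from the model
x = rng.uniform(0, 1, size=(5, 5))   # start from random noise

start = activation(x, w)
for _ in range(50):
    grad = w                          # d(activation)/dx for a linear filter
    x = x + 0.1 * rms_normalize(grad) # gradient ascent on the input, not on the weights
end = activation(x, w)

# The optimized input ends up resembling the pattern the filter responds to.
similarity = np.corrcoef(x.ravel(), w.ravel())[0, 1]
print(start, end, similarity)
```

In the CNN case the gradient is no longer available in closed form, so it has to be obtained by automatic differentiation, but the update of the input image follows the same pattern.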

In [9]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

KTF = keras.backend
layers = keras.layers

print("keras", keras.__version__)
keras 2.4.0

Download and preprocess data

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32)[...,np.newaxis] / 255.
x_test = x_test.astype(np.float32)[...,np.newaxis] / 255.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
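To make the effect of the preprocessing explicit, here is a small NumPy sketch on two hypothetical stand-in images (not real MNIST samples). It assumes that one-hot encoding via rows of an identity matrix is equivalent to what `to_categorical` produces for integer labels.

```python
import numpy as np

# Hypothetical stand-ins for two 28x28 uint8 images and their integer labels.
images = np.array([np.zeros((28, 28)), np.full((28, 28), 255)], dtype=np.uint8)
labels = np.array([3, 7])

# Scale pixel values to [0, 1] and append a channel axis -> shape (N, 28, 28, 1).
x = images.astype(np.float32)[..., np.newaxis] / 255.

# One-hot encode: row i of the 10x10 identity matrix encodes class i.
y = np.eye(10, dtype=np.float32)[labels]

print(x.shape, x.min(), x.max())
print(y)
```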

Set up a convolutional neural network with at least 4 convolutional layers.

In [3]:
model = keras.models.Sequential([
    layers.Conv2D(16, (3,3), activation='relu', padding='same', input_shape=(28,28,1), name='conv2d_1'),
    layers.Conv2D(32, (3,3), activation='relu', padding='same', name='conv2d_2'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu', padding='same', name='conv2d_3'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(128, (3,3), activation='relu', padding='same', name='conv2d_4'),
    layers.GlobalMaxPooling2D(),
    layers.Dropout(0.5),
    layers.Dense(10),
    layers.Activation('softmax', name='softmax_layer')])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 16)        160       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 32)        4640      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
global_max_pooling2d (Global (None, 128)               0         
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                1290      
_________________________________________________________________
softmax_layer (Activation)   (None, 10)                0         
=================================================================
Total params: 98,442
Trainable params: 98,442
Non-trainable params: 0
_________________________________________________________________
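The parameter counts in the summary can be checked by hand. The sketch below uses the standard formulas for a convolutional layer with bias, (kernel height x kernel width x input channels + 1) x output channels, and for a dense layer with bias, (inputs + 1) x outputs. The helper names are illustrative only.

```python
def conv2d_params(kernel, c_in, c_out):
    """Weights plus one bias per output channel for a square kernel."""
    return (kernel * kernel * c_in + 1) * c_out

def dense_params(n_in, n_out):
    """Weights plus one bias per output unit."""
    return (n_in + 1) * n_out

counts = {
    "conv2d_1": conv2d_params(3, 1, 16),
    "conv2d_2": conv2d_params(3, 16, 32),
    "conv2d_3": conv2d_params(3, 32, 64),
    "conv2d_4": conv2d_params(3, 64, 128),
    "dense": dense_params(128, 10),
}
total = sum(counts.values())

for name, n in counts.items():
    print(name, n)
print("total", total)
```

Pooling, dropout and the softmax activation contribute no parameters, so the sum of these five layers matches the total reported by `model.summary()`.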

Compile and train the model

In [4]:
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),  # 'lr' is a deprecated alias
    metrics=['accuracy'])


results = model.fit(x_train, y_train,
                    batch_size=100,
                    epochs=3,
                    verbose=1,
                    validation_split=0.1
                    )
Epoch 1/3
540/540 [==============================] - 45s 83ms/step - loss: 0.9091 - accuracy: 0.6911 - val_loss: 0.0624 - val_accuracy: 0.9812
Epoch 2/3
540/540 [==============================] - 47s 87ms/step - loss: 0.1476 - accuracy: 0.9575 - val_loss: 0.0420 - val_accuracy: 0.9870
Epoch 3/3
540/540 [==============================] - 44s 82ms/step - loss: 0.0999 - accuracy: 0.9712 - val_loss: 0.0334 - val_accuracy: 0.9900
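The 540 steps per epoch shown in the log follow from the training settings. A short arithmetic check, assuming the standard MNIST training set of 60,000 images:

```python
import math

n_total = 60000          # size of the MNIST training set
validation_split = 0.1   # fraction held out for validation in model.fit
batch_size = 100

n_val = round(n_total * validation_split)   # samples used for validation
n_fit = n_total - n_val                     # samples used for weight updates
steps_per_epoch = math.ceil(n_fit / batch_size)

print(n_val, n_fit, steps_per_epoch)
```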

Implementation of activation maximization

Select a layer you want to visualize and perform activation maximization.

In [5]:
gradient_updates = 50
step_size = 1.

def normalize(x):
    '''Normalize gradients by their root mean square (RMS)'''
    return x / (KTF.sqrt(KTF.mean(KTF.square(x))) + KTF.epsilon())
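Note that dividing by the square root of the mean of squares rescales the gradient to unit RMS, which differs from a unit L2 norm by a factor of the square root of the number of elements. The NumPy sketch below illustrates the difference. The function names and the value `eps=1e-7` (assumed here to mirror the Keras default epsilon) are illustrative.

```python
import numpy as np

def rms_normalize(g, eps=1e-7):
    """Divide by the root mean square, as in normalize() above."""
    return g / (np.sqrt(np.mean(np.square(g))) + eps)

def l2_normalize(g, eps=1e-7):
    """Divide by the L2 norm (square root of the sum of squares)."""
    return g / (np.sqrt(np.sum(np.square(g))) + eps)

# Hypothetical gradient with the shape of one MNIST input image.
g = np.random.default_rng(0).normal(size=(1, 28, 28, 1))

g_rms = rms_normalize(g)
g_l2 = l2_normalize(g)

rms_after = np.sqrt(np.mean(np.square(g_rms)))        # close to 1
l2_after = np.linalg.norm(g_l2)                       # close to 1
ratio = np.linalg.norm(g_rms) / np.linalg.norm(g_l2)  # close to sqrt(784) = 28

print(rms_after, l2_after, ratio)
```

With RMS normalization, a step size of 1 changes each pixel by roughly 1 on average, independent of the image size.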
In [6]:
visualized_feature = []
layer_dict = {layer.name: layer for layer in model.layers}
layer_name = "conv2d_3"

layer_output = layer_dict[layer_name].output
sub_model = keras.models.Model(model.inputs, layer_output)  # model.inputs is already a list

for filter_index in range(layer_output.shape[-1]):

    print('Processing feature map %d' % (filter_index+1))
    # activation maximization
    input_img = KTF.variable(np.random.uniform(0,1, (1, 28, 28, 1)))
    for i in range(gradient_updates):

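        with tf.GradientTape() as gtape:
            # forward pass; input_img is a variable, so the tape tracks it
            feature_maps = sub_model(input_img)
            loss = KTF.mean(feature_maps[..., filter_index])  # objective: mean activation of one feature map

        # gradient ascent step on the input image
        grads = gtape.gradient(loss, input_img)
        grads = normalize(grads)
        input_img.assign_add(step_size * grads)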

    visualized_feature.append(input_img.numpy())  # cast to numpy array
Processing feature map 1
Processing feature map 2
...
Processing feature map 64

Plot images to visualize to which patterns the respective feature maps are sensitive.

In [7]:
def deprocess_image(x):
    # convert the optimized input to the format of "MNIST images" (uint8 in [0, 255])
    x = x - x.mean()  # non in-place, so the stored array is left unchanged
    x = x / (x.std() + KTF.epsilon())
    # x *= 0.1
    x = x + 0.5
    x = x * 255
    x = np.clip(x, 0, 255).astype('uint8')
    return x
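The commented-out line `x *= 0.1` controls the contrast of the result. The NumPy sketch below illustrates this with a hypothetical `contrast` parameter that is not part of the notebook: without the scaling, every pixel more than half a standard deviation from the mean is clipped to 0 or 255, whereas with a factor of 0.1 almost no pixel saturates.

```python
import numpy as np

def deprocess(x, contrast=1.0, eps=1e-7):
    """Standardize, rescale by `contrast`, shift to 0.5 and map to uint8."""
    x = (x - x.mean()) / (x.std() + eps)
    x = x * contrast + 0.5
    return np.clip(x * 255, 0, 255).astype('uint8')

def saturated_fraction(img):
    """Fraction of pixels that ended up at 0 or 255."""
    return np.mean((img == 0) | (img == 255))

# Hypothetical optimized input; random values stand in for a real result.
raw = np.random.default_rng(0).normal(size=(28, 28))

full = deprocess(raw)                 # behaviour with 'x *= 0.1' commented out
soft = deprocess(raw, contrast=0.1)   # behaviour with 'x *= 0.1' enabled

frac_full = saturated_fraction(full)
frac_soft = saturated_fraction(soft)
print(frac_full, frac_soft)
```

Which setting is preferable depends on the goal: high contrast makes the dominant pattern easy to see, while the softer scaling preserves gradations within it.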
In [8]:
plt.figure(figsize=(10,10))

for i, feature_ in enumerate(visualized_feature):
    feature_image = deprocess_image(feature_)
    ax = plt.subplot(8, 8, 1+i)
    plt.imshow(feature_image.squeeze())
    ax.axis('off')
    plt.title("feature %d" % (i + 1))  # 1-based, matching the progress output above
    
plt.tight_layout()