Exercise 12.2

Activation maximization

In this task, we use activation maximization to visualize the patterns to which individual features (feature maps) of a CNN trained on MNIST are sensitive. This gives us a deeper understanding of how CNNs work.

In [3]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

KTF = keras.backend
layers = keras.layers

print("keras", keras.__version__)
keras 2.4.0

Download and preprocess data

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32)[...,np.newaxis] / 255.
x_test = x_test.astype(np.float32)[...,np.newaxis] / 255.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
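The following NumPy-only sketch applies the same preprocessing steps to a small random stand-in batch, which makes the resulting array shapes explicit without downloading MNIST. The fake data and its size are assumptions for illustration. For integer labels 0 to 9, indexing np.eye(10) gives the same one-hot layout that keras.utils.to_categorical produces.

```python
import numpy as np

# Hypothetical stand-in for a few MNIST samples: uint8 images and integer labels
rng = np.random.default_rng(0)
x_fake = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
y_fake = np.array([3, 0, 9, 1, 3])

# Same steps as above: cast to float32, add a trailing channel axis, scale to [0, 1]
x_prep = x_fake.astype(np.float32)[..., np.newaxis] / 255.

# One-hot encoding of the labels (10 classes), one row per sample
y_onehot = np.eye(10, dtype=np.float32)[y_fake]

print(x_prep.shape, x_prep.min() >= 0.0, x_prep.max() <= 1.0)
print(y_onehot.shape, y_onehot[0])
```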

Set up a convolutional neural network with at least 4 convolutional layers.

In [ ]:
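model = keras.models.Sequential()
# add your layers here, starting with model.add(keras.Input(shape=(28, 28, 1)))
# so that the model is built and model.summary() can be called; use at least 4 Conv2D layers

model.summary()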
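One possible architecture that satisfies this requirement is sketched below. The number of filters, the padding, the pooling layers and the dropout rate are assumptions, not a prescribed solution. When four Conv2D layers are created in a fresh session, Keras names them conv2d, conv2d_1, conv2d_2 and conv2d_3 by default, which matches the layer name used in the activation maximization cell; since default names are counted per session, check model.summary() if you rebuild the model.

```python
import numpy as np
from tensorflow import keras

layers = keras.layers


def build_cnn(num_classes=10):
    """Sketch of a small CNN for 28x28x1 MNIST images with four Conv2D layers."""
    return keras.models.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
        layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
        layers.MaxPooling2D((2, 2)),  # 28x28 -> 14x14
        layers.Conv2D(64, (3, 3), padding="same", activation="relu"),
        layers.Conv2D(64, (3, 3), padding="same", activation="relu"),
        layers.MaxPooling2D((2, 2)),  # 14x14 -> 7x7
        layers.Flatten(),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),  # class probabilities
    ])


model = build_cnn()
model.summary()
```

In this sketch the last convolutional layer has 64 filters, so all of its feature maps fit into the 8 x 8 plotting grid used at the end of the notebook.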

Compile and train the model

In [ ]:
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'])


results = model.fit(x_train, y_train,
                    batch_size=100,
                    epochs=3,
                    verbose=1,
                    validation_split=0.1
                    )

Implementation of activation maximization

Select a layer you want to visualize and perform activation maximization.

In [ ]:
gradient_updates = 50
step_size = 1.

def normalize(x):
    '''Normalize gradients by their root mean square (the L2 norm divided by sqrt(number of elements))'''
    return x / (KTF.sqrt(KTF.mean(KTF.square(x))) + KTF.epsilon())
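The function above divides by the root mean square (RMS) of the gradient rather than by its plain L2 norm. The normalized gradient therefore has an RMS of about 1, and an update step_size * grads has an RMS of about step_size, independent of the raw gradient scale. The NumPy-only check below illustrates this; normalize_np is a hypothetical stand-in for normalize, with eps playing the role of KTF.epsilon() (1e-7 by default).

```python
import numpy as np


def normalize_np(x, eps=1e-7):
    """NumPy counterpart of normalize(): divide by the root mean square of x."""
    return x / (np.sqrt(np.mean(np.square(x))) + eps)


g = np.array([3.0, -4.0, 0.0, 0.0])
g_rms = normalize_np(g)        # RMS of g is 2.5, so this is about [1.2, -1.6, 0, 0]
g_l2 = g / np.linalg.norm(g)   # plain L2 normalization for comparison: [0.6, -0.8, 0, 0]

print(g_rms)
print(g_l2)
print(np.sqrt(np.mean(g_rms ** 2)))  # RMS of the normalized gradient, close to 1
```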

In the following, implement activation maximization to find the patterns to which a specific feature map is sensitive:

  • Start from a uniformly distributed noise image (note that its shape has to be (1, 28, 28, 1), as we use a batch size of 1).
  • Choose one specific feature map using filter_index.
  • Define a scalar loss as discussed in Chapter 12: the average activation of the chosen feature map, which we want to maximize.
  • Thereafter, add the calculated gradients to your start image (gradient ascent step) and repeat the procedure gradient_updates = 50 times. You can calculate the gradients as follows (the objective has to be computed inside the tape so that its gradient can be tracked):
    with tf.GradientTape() as gtape:
        YOUR_OBJECTIVE = ...  # compute the loss from the current input image here
    grads = gtape.gradient(YOUR_OBJECTIVE, THE_VARIABLE_YOU_WANT_TO_OPTIMIZE)
    grads = normalize(grads)

  • Finally, implement the gradient ascent step (use assign_add to add step_size * grads to the input variable, or equivalently assign_sub with the negated gradients) and perform the 50 updates.
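Written out, the objective and the update from the list above read as follows. Here $a^{(k)}_{ij}(x)$ denotes the activation of feature map $k$ (filter_index) at spatial position $(i, j)$ of the chosen layer for input image $x$, $H \times W$ is the size of that feature map, $N = 28 \cdot 28$ is the number of input pixels, $\eta$ is step_size and $\epsilon$ is the small constant used in normalize (this notation is introduced here for illustration):

```latex
J_k(x) = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} a^{(k)}_{ij}(x),
\qquad g = \nabla_x J_k(x)

\hat{g} = \frac{g}{\sqrt{\frac{1}{N} \sum_{n=1}^{N} g_n^2} + \epsilon},
\qquad x \leftarrow x + \eta \, \hat{g}
```

The update is repeated gradient_updates = 50 times, starting from uniform noise.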

Remember to construct a Keras variable for the input: we want to find an input that maximizes the output, so the input itself must hold adaptive parameters that we can train using TensorFlow / Keras. The following code snippet may help you implement the maximization:

In [ ]:
visualized_feature = []
layer_dict = {layer.name: layer for layer in model.layers}
layer_name = "conv2d_3"  # adapt to your model; check the layer names in model.summary()

layer_output = layer_dict[layer_name].output
sub_model = keras.models.Model(inputs=model.inputs, outputs=layer_output)

for filter_index in range(layer_output.shape[-1]):  # iterate over filters

    print('Processing filter %d' % (filter_index+1))

    input_img = KTF.variable([0])  # instead of '[0]', use uniform noise of shape (1, 28, 28, 1) as the start image

    for i in range(gradient_updates):

        with tf.GradientTape() as gtape:
            # define a scalar loss using Keras.
            # remember: you would like to maximize the activations in the respective feature map!
            loss = 0  # <--: define your loss HERE

        # <--: compute and normalize the gradients, then perform the gradient ascent step HERE

    visualized_feature.append(input_img.numpy())  # store the optimized image for plotting
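For reference, one possible way to fill in the steps listed above is sketched below as a self-contained function, run on a tiny untrained model so that it executes without MNIST. The function name activation_maximization, the toy model and its layer name "toy_conv" are hypothetical, and the input shape (1, 28, 28, 1) is hard-coded. With the trained notebook model you would call activation_maximization(sub_model, filter_index) for each filter of the chosen layer and append the returned image to visualized_feature.

```python
import tensorflow as tf
from tensorflow import keras


def activation_maximization(sub_model, filter_index, gradient_updates=50,
                            step_size=1.0, seed=0):
    """Gradient ascent on the input image to maximize the mean activation of
    one feature map in the output of sub_model (a sketch, not a reference solution)."""
    # start image: uniform noise in [0, 1) with batch size 1, stored as a trainable variable
    input_img = tf.Variable(tf.random.uniform((1, 28, 28, 1), seed=seed), trainable=True)
    history = []  # objective value before each update, plus the final value
    for _ in range(gradient_updates):
        with tf.GradientTape() as gtape:
            feature_maps = sub_model(input_img, training=False)
            # scalar objective: mean activation of the chosen feature map
            loss = tf.reduce_mean(feature_maps[..., filter_index])
        history.append(float(loss))
        grads = gtape.gradient(loss, input_img)
        # RMS normalization, as in normalize() above
        grads = grads / (tf.sqrt(tf.reduce_mean(tf.square(grads))) + 1e-7)
        input_img.assign_add(step_size * grads)  # gradient ascent step
    final = tf.reduce_mean(sub_model(input_img, training=False)[..., filter_index])
    history.append(float(final))
    return input_img.numpy(), history


# Hypothetical toy model standing in for the trained CNN
toy_model = keras.models.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(8, (3, 3), activation="relu", name="toy_conv"),
    keras.layers.Conv2D(8, (3, 3), activation="relu"),
])
toy_sub_model = keras.models.Model(
    inputs=toy_model.inputs, outputs=toy_model.get_layer("toy_conv").output)

toy_features = []
for filter_index in range(toy_model.get_layer("toy_conv").filters):
    image, history = activation_maximization(toy_sub_model, filter_index, gradient_updates=10)
    toy_features.append(image)
    print("filter %d: mean activation %.3f -> %.3f" % (filter_index, history[0], history[-1]))
```

If a filter happens to be inactive for the start image, its gradient is zero, the normalized update is zero as well, and the image stays unchanged; starting from a different noise seed is one simple way around this.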

Plot the optimized images to visualize the patterns to which the respective feature maps are sensitive.

In [ ]:
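def deprocess_image(x):
    # rescale the optimized input to the format of MNIST images (uint8 values from 0 to 255)
    x = x - x.mean()  # creates a copy, so the stored array is not modified in place
    x /= (x.std() + KTF.epsilon())
    # x *= 0.1  # optional: reduce the spread; without it, many pixels saturate at 0 or 255
    x += 0.5
    x *= 255
    x = np.clip(x, 0, 255).astype('uint8')
    return x
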
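Whether the commented-out line x *= 0.1 is enabled makes a large visual difference: after standardization, a pixel stays strictly inside the range 0 to 255 only if its standardized value lies within ±0.5, or within ±5 with the extra scaling. The NumPy sketch below compares both variants on random standard-normal data; deprocess_np, its spread argument and the saturated_fraction helper are hypothetical names introduced here.

```python
import numpy as np


def deprocess_np(x, spread=None, eps=1e-7):
    """NumPy sketch of deprocess_image(); spread=0.1 corresponds to enabling 'x *= 0.1'."""
    x = (x - x.mean()) / (x.std() + eps)  # zero mean, unit standard deviation
    if spread is not None:
        x = x * spread                    # shrink the spread before shifting
    x = (x + 0.5) * 255
    return np.clip(x, 0, 255).astype('uint8')


def saturated_fraction(img):
    """Fraction of pixels that end up at pure black (0) or pure white (255)."""
    return float(np.mean((img == 0) | (img == 255)))


rng = np.random.default_rng(0)
raw = rng.normal(size=(1, 28, 28, 1))  # stand-in for an optimized input image

print("without x *= 0.1:", saturated_fraction(deprocess_np(raw)))
print("with    x *= 0.1:", saturated_fraction(deprocess_np(raw, spread=0.1)))
```

For normally distributed values, the probability of |z| > 0.5 is about 0.62, so without the extra scaling a majority of the pixels saturate.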
In [ ]:
plt.figure(figsize=(10, 10))

for i, feature_ in enumerate(visualized_feature):  # the 8 x 8 grid holds up to 64 feature maps
    feature_image = deprocess_image(feature_)
    ax = plt.subplot(8, 8, 1 + i)
    plt.imshow(feature_image.squeeze(), cmap='gray')
    ax.axis('off')
    plt.title("feature %s" % i)

plt.tight_layout()
plt.show()