import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
KTF = keras.backend
layers = keras.layers
# keras version 2.4.0
# Load MNIST, add a trailing channel axis, scale pixels to [0, 1],
# and one-hot encode the ten digit labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype(np.float32) / 255.
x_test = np.expand_dims(x_test, -1).astype(np.float32) / 255.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
# A small CNN for MNIST: stacked conv blocks with max pooling, a
# global-max-pooled feature vector, dropout, and a 10-way softmax head.
model = keras.models.Sequential()
model.add(layers.Conv2D(16, (3, 3), activation='relu', padding='same',
                        input_shape=(28, 28, 1), name='conv2d_1'))
model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same',
                        name='conv2d_2'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same',
                        name='conv2d_3'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same',
                        name='conv2d_4'))
model.add(layers.GlobalMaxPooling2D())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10))
model.add(layers.Activation('softmax', name='softmax_layer'))
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_1 (Conv2D) (None, 28, 28, 16) 160 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 32) 4640 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 14, 14, 32) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 14, 14, 64) 18496 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 7, 7, 128) 73856 _________________________________________________________________ global_max_pooling2d (Global (None, 128) 0 _________________________________________________________________ dropout (Dropout) (None, 128) 0 _________________________________________________________________ dense (Dense) (None, 10) 1290 _________________________________________________________________ softmax_layer (Activation) (None, 10) 0 ================================================================= Total params: 98,442 Trainable params: 98,442 Non-trainable params: 0 _________________________________________________________________
# Compile and train the classifier.
# NOTE: the `lr` keyword is deprecated in tf.keras optimizers in favor of
# `learning_rate` (same value, no behavior change).
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'])
# Train for 3 epochs; 10% of the training set is held out for validation.
results = model.fit(x_train, y_train,
                    batch_size=100,
                    epochs=3,
                    verbose=1,
                    validation_split=0.1)
Epoch 1/3 540/540 [==============================] - 45s 83ms/step - loss: 0.9091 - accuracy: 0.6911 - val_loss: 0.0624 - val_accuracy: 0.9812 Epoch 2/3 540/540 [==============================] - 47s 87ms/step - loss: 0.1476 - accuracy: 0.9575 - val_loss: 0.0420 - val_accuracy: 0.9870 Epoch 3/3 540/540 [==============================] - 44s 82ms/step - loss: 0.0999 - accuracy: 0.9712 - val_loss: 0.0334 - val_accuracy: 0.9900
Select a layer you want to visualize and perform activation maximization.
gradient_updates = 50  # gradient-ascent steps per filter
step_size = 1.         # ascent step size


def normalize(x):
    """Normalize gradients via l2 norm (scale to unit RMS magnitude)."""
    rms = KTF.sqrt(KTF.mean(KTF.square(x)))
    return x / (rms + KTF.epsilon())
# Activation maximization: for each filter of the chosen layer, start from
# random noise and run gradient ascent on the mean filter activation.
visualized_feature = []
# Fix: the original had a duplicated assignment (`layer_dict = layer_dict = ...`).
layer_dict = dict((layer.name, layer) for layer in model.layers)
layer_name = "conv2d_3"
layer_output = layer_dict[layer_name].output
sub_model = keras.models.Model([model.inputs], [layer_output])

for filter_index in range(layer_output.shape[-1]):
    print('Processing feature map %d' % (filter_index + 1))
    # activation maximization: start from uniform random noise in [0, 1)
    input_img = KTF.variable(np.random.uniform(0, 1, (1, 28, 28, 1)))
    for i in range(gradient_updates):
        with tf.GradientTape() as gtape:
            # Fix: the original re-bound `layer_output` here, shadowing the
            # symbolic tensor that defines the loop bound above; use a fresh name.
            activations = sub_model(input_img)
            loss = KTF.mean(activations[..., filter_index])  # objective
        grads = gtape.gradient(loss, input_img)
        grads = normalize(grads)  # unit-norm gradients keep step sizes stable
        input_img.assign_add(step_size * grads)
    visualized_feature.append(input_img.numpy())  # cast to numpy array
Processing feature map 1 Processing feature map 2 Processing feature map 3 Processing feature map 4 Processing feature map 5 Processing feature map 6 Processing feature map 7 Processing feature map 8 Processing feature map 9 Processing feature map 10 Processing feature map 11 Processing feature map 12 Processing feature map 13 Processing feature map 14 Processing feature map 15 Processing feature map 16 Processing feature map 17 Processing feature map 18 Processing feature map 19 Processing feature map 20 Processing feature map 21 Processing feature map 22 Processing feature map 23 Processing feature map 24 Processing feature map 25 Processing feature map 26 Processing feature map 27 Processing feature map 28 Processing feature map 29 Processing feature map 30 Processing feature map 31 Processing feature map 32 Processing feature map 33 Processing feature map 34 Processing feature map 35 Processing feature map 36 Processing feature map 37 Processing feature map 38 Processing feature map 39 Processing feature map 40 Processing feature map 41 Processing feature map 42 Processing feature map 43 Processing feature map 44 Processing feature map 45 Processing feature map 46 Processing feature map 47 Processing feature map 48 Processing feature map 49 Processing feature map 50 Processing feature map 51 Processing feature map 52 Processing feature map 53 Processing feature map 54 Processing feature map 55 Processing feature map 56 Processing feature map 57 Processing feature map 58 Processing feature map 59 Processing feature map 60 Processing feature map 61 Processing feature map 62 Processing feature map 63 Processing feature map 64
def deprocess_image(x):
    """Reprocess a visualization array into a displayable MNIST-style image.

    Standardizes to zero mean / unit std, recenters around 0.5, scales to
    [0, 255], clips, and casts to uint8.

    Fixes vs. original: operates on a copy instead of mutating the caller's
    array in place (the original corrupted the entries kept in
    ``visualized_feature``), and uses a plain constant epsilon instead of
    ``KTF.epsilon()`` — this is pure NumPy postprocessing and the Keras
    backend default is exactly 1e-7, so behavior is unchanged.
    """
    eps = 1e-7  # == keras.backend.epsilon() default
    x = np.array(x, dtype=np.float64)  # always copies; caller's array untouched
    x -= x.mean()
    x /= (x.std() + eps)  # guard against zero std on constant images
    x += 0.5
    x *= 255
    return np.clip(x, 0, 255).astype('uint8')
# Render all visualized filters on an 8x8 grid of subplots.
plt.figure(figsize=(10, 10))
for idx, feat in enumerate(visualized_feature):
    img = deprocess_image(feat)
    ax = plt.subplot(8, 8, idx + 1)
    plt.imshow(img.squeeze())
    ax.axis('off')
    plt.title("feature %s" % idx)
    plt.tight_layout()