In this notebook we will explore the use of image gradients for generating new images.
When training a model, we define a loss function which measures our current unhappiness with the model's performance; we then use backpropagation to compute the gradient of the loss with respect to the model parameters, and perform gradient descent on the model parameters to minimize the loss.
Here we will do something slightly different. We will start from a convolutional neural network model which has been pretrained to perform image classification on the ImageNet dataset. We will use this model to define a loss function which quantifies our current unhappiness with our image, then use backpropagation to compute the gradient of this loss with respect to the pixels of the image. We will then keep the model fixed, and perform gradient descent on the image to synthesize a new image which minimizes the loss.
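To make the idea of optimizing the image rather than the model concrete, here is a minimal NumPy sketch. It is not the notebook's method: the "model" is a hypothetical fixed linear scorer `w`, the "image" is a 12-element vector `x`, and the loss (negative score plus an L2 penalty) and step size are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pretrained model: the weights of a linear scorer.
# They are never updated below.
w = rng.standard_normal(12)
w_before = w.copy()

def loss_and_grad(x, reg=1.0):
    """Loss = -score + L2 penalty on the image; returns (loss, dloss/dx)."""
    loss = -np.dot(w, x) + 0.5 * reg * np.dot(x, x)
    grad = -w + reg * x
    return loss, grad

# The "image" being synthesized: 12 pixel values, starting from zeros.
x = np.zeros(12)
learning_rate = 0.1
losses = []
for _ in range(200):
    loss, grad = loss_and_grad(x)
    losses.append(loss)
    x -= learning_rate * grad  # gradient descent on the image, not on w
```

The loop only ever modifies `x`; the weights `w` play the role of the frozen pretrained network.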
In this notebook we will explore three techniques for image generation:
This notebook uses PyTorch; we have provided another notebook which explores the same concepts in TensorFlow. You only need to complete one of these two notebooks.
import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import random
import numpy as np
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
from cs231n.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
from PIL import Image
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
Our pretrained model was trained on images that had been preprocessed by subtracting the per-color mean and dividing by the per-color standard deviation. We define a few helper functions for performing and undoing this preprocessing. You don't need to do anything in this cell.
def preprocess(img, size=224):
    transform = T.Compose([
        T.Resize(size),  # T.Scale was renamed to T.Resize in torchvision
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]),  # add a leading batch dimension
    ])
    return transform(img)

def deprocess(img, should_rescale=True):
    transform = T.Compose([
        T.Lambda(lambda x: x[0]),  # drop the batch dimension
        T.Normalize(mean=[0, 0, 0], std=(1.0 / SQUEEZENET_STD).tolist()),
        T.Normalize(mean=(-SQUEEZENET_MEAN).tolist(), std=[1, 1, 1]),
        T.Lambda(rescale) if should_rescale else T.Lambda(lambda x: x),
        T.ToPILImage(),
    ])
    return transform(img)

def rescale(x):
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled

def blur_image(X, sigma=1):
    X_np = X.cpu().clone().numpy()
    X_np = gaussian_filter1d(X_np, sigma, axis=2)
    X_np = gaussian_filter1d(X_np, sigma, axis=3)
    X.copy_(torch.Tensor(X_np).type_as(X))
    return X
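The preprocessing helpers above normalize each color channel and later undo that normalization with two chained Normalize calls. The following NumPy sketch shows why that pair of calls inverts the forward step. The `MEAN` and `STD` values and the function names `normalize` and `denormalize` are assumptions for illustration; the notebook takes its constants from cs231n.image_utils.

```python
import numpy as np

# Illustrative per-channel statistics (assumed values; the notebook imports its
# own constants from cs231n.image_utils).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize(img):
    """img: float array of shape (3, H, W) with values in [0, 1]."""
    return (img - MEAN[:, None, None]) / STD[:, None, None]

def denormalize(img):
    """Undo normalize() in two steps, mirroring the two Normalize calls."""
    img = img / (1.0 / STD)[:, None, None]  # mean 0, std 1/STD: multiplies by STD
    img = img - (-MEAN)[:, None, None]      # mean -MEAN, std 1: adds MEAN back
    return img

rng = np.random.default_rng(0)
image = rng.random((3, 4, 4))
restored = denormalize(normalize(image))
```

Up to floating-point error, `restored` matches `image`, which is the property deprocess relies on before converting back to a PIL image.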
For all of our image generation experiments, we will start with a convolutional neural network which was pretrained to perform image classification on ImageNet. We can use any model here, but for the purposes of this assignment we will use SqueezeNet [1], which achieves accuracies comparable to AlexNet but with a significantly reduced parameter count and computational complexity.
Using SqueezeNet rather than AlexNet or VGG or ResNet means that we can easily perform all image generation experiments on CPU.
[1] Iandola et al, "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and < 0.5MB model size", arXiv 2016
# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)
# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():
    param.requires_grad = False
We have provided a few example images from the validation set of the ImageNet ILSVRC 2012 Classification dataset. To download these images, change to cs231n/datasets/ and run get_imagenet_val.sh.
Since they come from the validation set, our pretrained model did not see these images during training.
Run the following cell to visualize some of these images, along with their ground-truth labels.
from cs231n.data_utils import load_imagenet_val
X, y, class_names = load_imagenet_val(num=5)
plt.figure(figsize=(12, 6))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(X[i])
    plt.title(class_names[y[i]])
    plt.axis('off')
plt.gcf().tight_layout()
Using this pretrained model, we will compute class saliency maps as described in Section 3.1 of [2].
A saliency map tells us the degree to which each pixel in the image affects the classification score for that image. To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar) with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W); for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are nonnegative.
[2] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. "Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps", ICLR Workshop 2014.
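The final reduction from a gradient of shape (3, H, W) to a saliency map of shape (H, W) can be sketched in NumPy as follows. The gradient values here are made up for illustration and the helper name `saliency_from_gradient` is hypothetical; in the notebook the gradient comes from backpropagation through the model.

```python
import numpy as np

def saliency_from_gradient(grad):
    """Reduce a (3, H, W) gradient to an (H, W) saliency map:
    absolute value first, then maximum over the channel axis."""
    return np.abs(grad).max(axis=0)

# Made-up gradient for a 2x2 image with 3 channels, shape (3, 2, 2).
grad = np.array([
    [[0.1, -0.5], [0.0, 0.2]],
    [[-0.3, 0.4], [0.0, -0.1]],
    [[0.2, 0.1], [0.0, 0.6]],
])
saliency = saliency_from_gradient(grad)
```

Taking the absolute value before the maximum matters: the entry -0.5 in the first channel dominates its pixel even though it is negative.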
The gather method
Recall that in Assignment 1 you needed to select one element from each row of a matrix: if s is a numpy array of shape (N, C) and y is a numpy array of shape (N,) containing integers 0 <= y[i] < C, then s[np.arange(N), y] is a numpy array of shape (N,) which selects one element from each row of s using the indices in y.
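As a quick refresher on the NumPy side, here is a small sketch of that row-wise selection. The array contents are arbitrary; `np.take_along_axis` is shown only as a NumPy operation with a similar shape contract to PyTorch's gather, not as something the assignment requires.

```python
import numpy as np

N, C = 4, 5
s = np.arange(N * C).reshape(N, C)  # row i holds the values 5*i .. 5*i + 4
y = np.array([1, 2, 1, 3])          # one column index per row

# Integer-array indexing: pair each row index with its column index.
picked = s[np.arange(N), y]

# The same selection with indices shaped (N, 1), then squeezed back to (N,).
picked_alt = np.take_along_axis(s, y[:, None], axis=1).squeeze(1)
```

Both `picked` and `picked_alt` have shape (N,) and contain s[i, y[i]] for each row i.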
In PyTorch you can perform the same operation using the gather() method. If s is a PyTorch Tensor or Variable of shape (N, C) and y is a PyTorch Tensor or Variable of shape (N,) containing longs in the range 0 <= y[i] < C, then
s.gather(1, y.view(-1, 1)).squeeze()
will be a PyTorch Tensor (or Variable) of shape (N,) containing one entry from each row of s, selected according to the indices in y.
Run the following cell to see an example.
You can also read the documentation for the gather method and the squeeze method.
# Example of using gather to select one entry from each row in PyTorch
def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(s.gather(1, y.view(-1, 1)).squeeze())

gather_example()
-0.9695  0.3291 -0.4649 -0.7073  1.7230
-0.9341  1.8693 -0.8987  0.0674  0.9908
 0.5223 -1.0662 -0.6448  0.9887 -0.1528
 1.2723  0.6175  0.2850 -0.0211 -2.0851
[torch.FloatTensor of size 4x5]

 1
 2
 1
 3
[torch.LongTensor of size 4]

 0.3291
-0.8987
-1.0662
-0.0211
[torch.FloatTensor of size 4]
def compute_saliency_maps(X, y, model):
    """
    Compute a class saliency map using the model for images X and labels y.

    Input:
    - X: Input images; Tensor of shape (N, 3, H, W)
    - y: Labels for X; LongTensor of shape (N,)
    - model: A pretrained CNN that will be used to compute the saliency map.

    Returns:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
      images.
    """
    # Make sure the model is in "test" mode
    model.eval()

    # Wrap the input tensors in Variables
    X_var = Variable(X, requires_grad=True)
    y_var = Variable(y)
    saliency = None
    ##############################################################################
    # TODO: Implement this function. Perform a forward and backward pass through #
    # the model to compute the gradient of the correct class score with respect  #
    # to each input image. You first want to compute the loss over the correct   #
    # scores, and then compute the gradients with a backward pass.               #
    ##############################################################################
    # Forward pass: unnormalized class scores of shape (N, C)
    scores = model(X_var)
    # Select the correct-class score for each image, shape (N,)
    correct_scores = scores.gather(1, y_var.view(-1, 1)).squeeze()
    # Sum the raw scores into a scalar. Each image only affects its own score,
    # so the gradient w.r.t. image i is the gradient of its correct-class score.
    # (Taking log of unnormalized scores would give NaN for negative scores.)
    loss = correct_scores.sum()
    loss.backward()
    # Saliency: absolute gradient, then maximum over the 3 color channels
    saliency = X_var.grad.data
    saliency = saliency.abs()
    saliency, idx = saliency.max(dim=1)
    ##############################################################################
    #                             END OF YOUR CODE                               #
    ##############################################################################
    return saliency.squeeze()
Once you have completed the implementation in the cell above, run the following to visualize some class saliency maps on our example images from the ImageNet validation set:
def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]
    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
    plt.gcf().set_size_inches(12, 5)
    plt.show()

show_saliency_maps(X, y)