This numerical tour implements the texture synthesis method detailed in the paper of Gatys et al. The implementation is intended to be as simple as possible, using PyTorch hooks so that it can be applied to any network (as opposed to the style transfer implementation).
This tour can be used as a gentle introduction to convolutional networks, where one can use a pre-trained network to perform a non-trivial vision task (involving in particular an optimization using back-propagation through the network).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
Uncomment if you want to store results to your own Google Drive.
# from google.colab import drive
# drive.mount('/content/drive')
Check if CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
Load the image $f \in \mathbb{R}^{3 \times n_0 \times n_0}$ where $3$ is the number of input channels.
from urllib.request import urlopen
import io
file_adress = 'https://raw.githubusercontent.com/leongatys/DeepTextures/master/Images/pebbles.jpg'
file_adress = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/radishes256.o.jpg'
file_adress = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/olives256.o.jpg'
file_adress = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/tomatoes256.o.jpg'
file_adress = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/yellow-peppers256.o.jpg'
fd = urlopen(file_adress)
image_file = io.BytesIO(fd.read())
f_pil = Image.open(image_file)
plt.imshow(f_pil)
plt.axis('off');
Image normalization (to match the normalization used during the network's training) $$ \forall \ell \in \{0,1,2\}, \quad f[\ell,\cdot,\cdot] \leftarrow (f[\ell,\cdot,\cdot]-m_\ell)/\sigma_\ell $$ where $m$ and $\sigma$ are the empirical mean and standard deviation of the training set (here the ImageNet dataset).
n = 256
def normalize(f):
    preprocess = transforms.Compose([
        transforms.Resize(n),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # add a batch dimension and move to the device selected above
    return preprocess(f).unsqueeze(0).to(device)
def deprocess(image):
    # undo the normalization; expects a channel-last image of shape (n,n,3)
    return image * torch.Tensor([0.229, 0.224, 0.225]).to(device) + torch.Tensor([0.485, 0.456, 0.406]).to(device)
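As a quick sanity check (a minimal sketch; f_check and f_back are temporary names introduced here for illustration), one can normalize the loaded image and undo the normalization to recover the original colors; note that deprocess operates on a channel-last tensor, hence the permute.
# normalize the image, then undo the normalization and display the result
f_check = normalize(f_pil)                               # shape (1,3,n,n)
f_back = deprocess(f_check.squeeze(0).permute(1, 2, 0))  # shape (n,n,3)
plt.imshow(f_back.cpu().clamp(0, 1))
plt.axis('off');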
Load a neural network architecture pre-trained on ImageNet.
nn_type = 'resnet'
nn_type = 'vgg'
if nn_type=='vgg':
    cnn = models.vgg19(pretrained=True)
elif nn_type=='resnet':
    cnn = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
# no need to store the gradient
for param in cnn.parameters():
    param.requires_grad = False
cnn.to(device);
Starting from $f_0 = f \in \mathbb{R}^{m_0 \times n_0 \times n_0}$ the input image (with $m_0=3$ channels), the "feature" part of the network alternates layers of the form $$ f_{i+1} \equiv \Psi_i(f_i) \equiv [ \text{ReLU}( f_{i} \star w_i) ]_{\downarrow s_i} = \Psi_i \circ \Psi_{i-1} \circ \ldots \circ \Psi_0(f_0). $$ Here $f_{i} \in \mathbb{R}^{m_i \times n_i \times n_i}$ has $m_i$ channels and $n_i^2$ pixels (we assume square images for simplicity), and $w_i[\ell,\ell',\cdot,\cdot]$ are the convolution filters, so that $$ \forall 0 \leq \ell < m_{i+1}, \quad ( f_{i} \star w_i )[ \ell,x,y ] = \sum_{\ell',x',y'} w_i[\ell,\ell',x-x',y-y'] f_{i}[\ell',x',y']. $$ ReLU is the Rectified Linear Unit non-linearity $\text{ReLU}(s)=\max(s,0)$, applied element-wise to the tensor.
The operation $[\cdot]_{\downarrow s_i}$ is a downsampling by a factor $s_i \in \{0,2\}$. If $s_i=0$, nothing is done (so that $n_{i+1}=n_i$), while if $s_i=2$, the number of pixels is reduced by a factor $4$ and $n_{i+1}=n_i/2$. The most usual sub-sampling operator (when $s_i=2$) is max-pooling, where $$ (A_{\downarrow 2})[\ell,x,y] \equiv \max(A[\ell,2x,2y],A[\ell,2x+1,2y],A[\ell,2x,2y+1],A[\ell,2x+1,2y+1]). $$
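The following minimal sketch (on a random tensor, with arbitrary sizes chosen for illustration) shows one such elementary layer $\Psi_i$: a convolution followed by a ReLU and a max-pooling with $s_i=2$, so that the spatial resolution is halved while the number of channels increases.
# one elementary layer Psi_i applied to a random feature tensor
f_i = torch.randn(1, 64, 32, 32)                      # m_i=64 channels, n_i=32
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1)   # convolution by filters w_i
f_ip1 = F.max_pool2d(F.relu(conv(f_i)), kernel_size=2)
print(f_ip1.shape)  # torch.Size([1, 128, 16, 16]): m_{i+1}=128, n_{i+1}=n_i/2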
In the following, we denote $$ \Phi_i(f_0) \equiv f_i \quad \text{i.e.} \quad \Phi_i \equiv \Psi_{i-1} \circ \ldots \circ \Psi_0, $$ the map from the input image $f_0$ to the output of the $i$-th layer of the network.
Note that for image classification tasks, this "feature" part is followed by a "classification" part, composed of a few fully connected (i.e. non-convolutional) layers and a final soft-max layer that outputs a probability vector over the classes of the dataset. During the training phase, a multi-class classification loss is minimized by stochastic gradient descent to tune the convolution filters $(w_i)_i$. This training is assumed to have already been done, and we do not optimize the weights $(w_i)_i$ in this tour.
We can display the architecture of the network, which has 16 feature layers $\Psi_i$, only 5 of which have $s_i=2$ (pooling). This means that the final spatial size of the feature maps is $n_{16}=n_0/2^5$. The number of channels grows as $m_0=3, 64, 128, 256, 512=m_{16}$.
cnn
VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (17): ReLU(inplace=True) (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (24): ReLU(inplace=True) (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (26): ReLU(inplace=True) (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (31): ReLU(inplace=True) (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (33): ReLU(inplace=True) (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (35): ReLU(inplace=True) (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) )
Function to save activations when applying a network.
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output #.detach()
    return hook
Register forward hooks so that evaluating the network stores the activations $(\Phi_i(f))_{i \in I}$ of a selected list $I$ of layers.
# select only a subset of the layers whose activations will be stored
if nn_type=='vgg':
    I = range(0,37) # all the layers
    I = [0] # first layer only
    I = [36] # last layer only
    I = [0, 4, 9, 18, 27] # first conv layer and the first four pooling layers
    it = 0
    for i in I:
        cnn.features[i].register_forward_hook(get_activation(it))
        it = it+1
elif nn_type=='resnet':
    it = 0
    cnn.conv1.register_forward_hook(get_activation(it)); it = it+1
    cnn.layer1[2].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer2[3].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer3[5].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer4[2].register_forward_hook(get_activation(it)); it = it+1
if nn_type=='vgg':
    for i in I:
        print( cnn.features[i] )
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Apply the network; this will save the activations in the variable activation.
f = normalize(f_pil)
cnn(f);
for a in activation:
    print(activation[a].shape)
torch.Size([1, 64, 256, 256]) torch.Size([1, 64, 128, 128]) torch.Size([1, 128, 64, 64]) torch.Size([1, 256, 32, 32]) torch.Size([1, 512, 16, 16])
Display the first channel of each saved activation, i.e. $(f_i[0,\cdot,\cdot])_{i \in I}$.
for a in activation:
    plt.subplot(2,3,a+1)
    plt.imshow(activation[a].cpu()[0,0,:,:].squeeze())
Now display the total activation over the channels by summation, i.e. $(\sum_\ell f_i[\ell,\cdot,\cdot])_{i \in I}$.
for a in activation:
    plt.subplot(2,3,a+1)
    plt.imshow(torch.sum(activation[a].cpu(), axis=1).squeeze())
The general idea of statistical texture synthesis (as opposed to "copy-based" methods) is to draw a random noise image and then modify it so that some of its empirical statistics match those of the input texture.
The initial idea appears in the early work of Heeger and Bergen, which simply matches histograms over a wavelet transform. This was refined by Zhu and Mumford, and by Portilla and Simoncelli, who use more complex statistical descriptors (in particular higher order moments). The idea of Gatys' method is similar, except that it replaces the linear wavelet transform by a non-linear neural network.
In this neural network texture model, one only makes use of second order moments.
We denote by $C(h) \in \mathbb{R}^{m \times m}$ the empirical (non-centered) covariance, or Gram matrix, of a feature image $h \in \mathbb{R} ^{m \times n \times n}$, defined as $$ \forall 0 \leq \ell,\ell'<m, \quad C(h)[\ell,\ell'] \equiv \frac{1}{n^2} \sum_{x,y} h[\ell,x,y] h[\ell',x,y] . $$
class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        feat = input.view(b, c, h*w)              # flatten the spatial dimensions (avoid shadowing torch.nn.functional as F)
        G = torch.bmm(feat, feat.transpose(1,2))  # (b,c,c) matrix of channel correlations
        G.div_(h*w) # normalize by the number of samples
        return G
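As a quick consistency check (a minimal sketch on a random tensor), the same quantity can also be computed with einsum, matching the definition of $C(h)$ above.
# cross-check the GramMatrix module against a direct einsum implementation
h_test = torch.randn(1, 8, 16, 16)
G1 = GramMatrix()(h_test)
G2 = torch.einsum('bcxy,bdxy->bcd', h_test, h_test) / (16*16)
print(torch.allclose(G1, G2))  # should print True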
Compute the list of all Gram matrices $( C(\Phi_i(f)) )_{i \in I}$ (important: use detach to obtain tensors detached from the computational graph).
f_gram = [GramMatrix()(activation[a]).detach() for a in activation]
Display those Gram matrices (the log of their absolute values).
for a in activation:
    plt.subplot(2,3,a+1)
    U = f_gram[a].cpu().squeeze()
    plt.imshow( torch.log(.01 + torch.abs(U)) )
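These Gram matrices are the statistics that the synthesized texture will have to reproduce. As a preview, here is a minimal sketch of the corresponding matching loss (assuming a candidate image g normalized as above; the function name gram_loss is introduced here only for illustration).
# sketch of the Gram-matching loss for a normalized candidate image g
def gram_loss(g):
    activation.clear()
    cnn(g)  # the forward hooks refill the activation dictionary
    loss = 0
    for a in activation:
        loss = loss + torch.sum((GramMatrix()(activation[a]) - f_gram[a])**2)
    return loss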