Neural Algorithm of Artistic Style Transfer

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

Setup

Import libraries

In [2]:
from fastai.conv_learner import *
from pathlib import Path
from scipy import ndimage

# torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True

Set up directory and file paths

In [3]:
PATH = Path('data/imagenet')
PATH_TRN = PATH / 'train'

Initialize pre-trained VGG model

In [4]:
m_vgg = to_gpu(vgg16(True)).eval()
set_trainable(m_vgg, False)

Grab and show an image from ImageNet

In [7]:
img_fn = PATH_TRN / 'n01558993' / 'n01558993_9684.JPEG'
img = open_image(img_fn)
plt.imshow(img)

Content loss

Let's start by creating a bird image from random noise. We use perceptual loss (content loss) to turn the noise into something bird-like, though not a copy of this particular bird.

In [8]:
sz = 288

Next we run the image through the transforms the VGG16 model expects.

In [9]:
trn_tfms, val_tfms = tfms_from_model(vgg16, sz)
img_tfm = val_tfms(img)
img_tfm.shape
Out[9]:
(3, 288, 288)

Now we have something that is 3 by 288 by 288, because PyTorch likes the channel to be first. As you can see, it's been turned into a square for us and normalized with the ImageNet per-channel means and standard deviations, all that normal stuff.
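For intuition, here is a minimal NumPy sketch of the last two steps of such a transform: per-channel normalization, then moving channels first. It assumes the standard ImageNet channel statistics (the values torchvision publishes, which are what fastai's VGG transforms normalize with) and an image that is already resized to 288 by 288. The helper name `to_vgg_input` is hypothetical, and the real fastai transform also does the resizing and cropping.

```python
import numpy as np

# Standard ImageNet per-channel statistics (RGB order), as published by torchvision.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_vgg_input(img_hwc):
    """Hypothetical helper: normalize an HxWx3 float image in [0, 1]
    channel by channel, then move channels first (3xHxW) for PyTorch."""
    normed = (img_hwc - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the last axis
    return np.transpose(normed, (2, 0, 1)).astype(np.float32)

img = np.random.uniform(0, 1, size=(288, 288, 3)).astype(np.float32)
x = to_vgg_input(img)
print(x.shape)  # (3, 288, 288)
```

After normalization the values are roughly zero-centred rather than confined to [0, 1], which is why the image later has to be denormalized before display.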

Our target is the bird image above. Now we create a random image to start from.

In [10]:
opt_img = np.random.uniform(0, 1, size=img.shape).astype(np.float32)
plt.imshow(opt_img)

Trying to turn this into a picture of anything is actually really hard. I found it very difficult to get an optimizer to produce reasonable gradients that went anywhere. And just as I thought I was going to run out of time for this class and really embarrass myself, I realized the key issue: real pictures don't look like this. They are smoother, so I blurred the noise a little with a median filter:

In [11]:
# Median-filter each channel separately with an 8x8 spatial window to smooth the noise
opt_img = ndimage.median_filter(opt_img, size=[8, 8, 1])
In [12]:
plt.imshow(opt_img)

As soon as we changed it to this, it immediately started training really well. The number of little tweaks you have to make to get these things to work is kind of insane, but this is one of them.
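A quick way to see what the median filter changes is to measure roughness directly. The sketch below, assuming SciPy's `ndimage.median_filter` with the same `[8, 8, 1]` window, compares the average absolute difference between horizontally adjacent pixels before and after filtering. The roughness measure `mean_abs_neighbor_diff` is a hypothetical helper chosen for illustration, not something the notebook uses.

```python
import numpy as np
from scipy import ndimage

def mean_abs_neighbor_diff(img):
    """Average absolute difference between horizontally adjacent pixels:
    a crude roughness measure (large for noise, small for smooth images)."""
    return float(np.abs(np.diff(img, axis=1)).mean())

rng = np.random.default_rng(0)
noise = rng.uniform(0, 1, size=(288, 288, 3)).astype(np.float32)
# 8x8 spatial window, size 1 along the channel axis so channels are filtered independently.
smooth = ndimage.median_filter(noise, size=(8, 8, 1))

print(mean_abs_neighbor_diff(noise), mean_abs_neighbor_diff(smooth))
```

For uniform noise the first number is close to 1/3; after filtering it drops sharply, because neighbouring 8x8 windows share most of their pixels and therefore have similar medians.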

So we start with a random image that is at least somewhat smooth. I found that my bird image had a mean pixel value about half of this one's, so we divide it by 2 to make it a little easier to match (we don't know whether it matters). We then turn it into a variable: we are going to modify these pixels with an optimization algorithm, and anything involved in the loss function needs to be a variable. Specifically, it requires a gradient, because we are updating the image itself.
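`V(..., requires_grad=True)` is fastai 0.7's wrapper around PyTorch's old `Variable`. In PyTorch 0.4 and later, tensors track gradients themselves, so the same idea can be expressed by flagging the tensor. This is a minimal sketch assuming a recent PyTorch; plain halved uniform noise stands in for the smoothed image, and the squared-pixel loss is only a placeholder for the real content loss.

```python
import torch

# Stand-in for opt_img: halved noise shaped as a mini batch of one (1 x 3 x H x W).
opt_img = torch.rand(1, 3, 288, 288) / 2
opt_img.requires_grad_(True)  # the pixels themselves are the parameters we optimize

# Placeholder differentiable loss; backward() fills in a gradient for every pixel.
loss = (opt_img ** 2).mean()
loss.backward()
print(opt_img.grad.shape)  # torch.Size([1, 3, 288, 288])
```

Because the image is the only tensor with `requires_grad=True`, an optimizer given `[opt_img]` updates the pixels and nothing else.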

In [13]:
opt_img = val_tfms(opt_img) / 2
opt_img_v = V(opt_img[None], requires_grad=True)
opt_img_v.shape
Out[13]:
torch.Size([1, 3, 288, 288])

So we now have a mini batch of one image: 3 channels of 288 by 288 random noise.

We are going to use, for no particular reason, the 37th layer of VGG. If you print out the VGG network (just type m_vgg and it prints out), you'll see that this is a mid-to-late-stage layer. So we can grab the first 37 layers and turn them into a sequential model. Now we have a subset of VGG that outputs some mid-layer activations, and that's what our model is going to be.
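Truncating a network this way works for any `nn.Sequential`. Below is a sketch on a small toy model, an assumption standing in for VGG so that it runs without downloading weights; `list(net.children())[:4]` plays the role of fastai's `children(m_vgg)[:37]`.

```python
import torch
from torch import nn

# A toy stand-in for VGG's convolutional trunk (not the real VGG16 layers).
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
# Keep only the first 4 children: conv, relu, pool, conv.
trunk = nn.Sequential(*list(net.children())[:4])

x = torch.randn(1, 3, 32, 32)
print(trunk(x).shape)  # torch.Size([1, 16, 16, 16])
```

The sliced model reuses the same layer objects, so no weights are copied: `trunk[0] is net[0]` holds.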

In [14]:
m_vgg = nn.Sequential(*children(m_vgg)[:37])

Now we take our actual bird image and make a mini batch of one. Remember that slicing a NumPy array with None (also known as np.newaxis) inserts a new unit axis at that position. Here, slicing with None puts a size-1 axis at the front, saying this is a mini batch of size one. We then turn it into a variable; this one doesn't need to be updated, so we use VV to say it needs no gradients. Running it through our truncated VGG gives us the target activations.
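A short NumPy illustration of the `None` indexing described above, using a zero array as a stand-in for `img_tfm`:

```python
import numpy as np

img_tfm = np.zeros((3, 288, 288), dtype=np.float32)  # stand-in for the transformed image
batch = img_tfm[None]                    # same as img_tfm[np.newaxis]
print(batch.shape)                       # (1, 3, 288, 288)
print(np.shares_memory(batch, img_tfm))  # True: a view, no copy
```

`np.expand_dims(img_tfm, 0)` gives the same result; all three forms return a view rather than a copy.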

In [15]:
targ_t = m_vgg(VV(img_tfm[None]))
targ_v = V(targ_t)
targ_t.shape
Out[15]:
torch.Size([1, 512, 18, 18])

Limited-memory Broyden–Fletcher–Goldfarb–Shanno (L-BFGS)

In [16]:
max_iter = 1000
show_iter = 100
optimizer = optim.LBFGS([opt_img_v], lr=0.5)
In [17]:
def actn_loss(x):
    return F.mse_loss(m_vgg(x), targ_v) * 1000
In [18]:
def step(loss_fn):
    global n_iter
    optimizer.zero_grad()
    # Evaluate the loss on the image being optimized (opt_img_v), then backpropagate
    # so LBFGS gets gradients with respect to its pixels
    loss = loss_fn(opt_img_v)
    loss.backward()
    n_iter += 1
    if n_iter % show_iter == 0:
        # loss.data[0] is PyTorch 0.3 style; use loss.item() on PyTorch >= 0.4
        print(f'Iteration: {n_iter}, loss: {loss.data[0]}')
    return loss

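As you can see here, with LBFGS you pass the loss function itself to optimizer.step, as a closure: LBFGS may evaluate the loss several times within a single step, so it needs a function it can call to recompute the loss and its gradients. Our closure is step with a particular loss function, our activation loss (actn_loss), bound in with partial.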
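The closure protocol is easier to see on a problem small enough to check by hand. This sketch, assuming a recent PyTorch, uses `torch.optim.LBFGS` to fit a single parameter `w` so that `3 * w` is close to 6; the toy loss and the evaluation counter are illustrative only.

```python
import torch

w = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.LBFGS([w], lr=0.5)
counter = {"evals": 0}  # counts how often LBFGS calls the closure

def closure():
    """Re-evaluate the loss and its gradient; LBFGS may call this several times per step."""
    optimizer.zero_grad()
    loss = ((w * 3 - 6) ** 2).sum()
    loss.backward()
    counter["evals"] += 1
    return loss

for _ in range(5):
    optimizer.step(closure)

print(w.item(), counter["evals"])
```

Counting calls shows that LBFGS runs the closure more often than `step` is called: by default each `step` performs up to `max_iter=20` inner iterations. This is also why `n_iter` in the notebook counts closure evaluations, not calls to `optimizer.step`.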

We run that a bunch of times, printing the loss every show_iter iterations. At the end we have our bird, though not a pixel-for-pixel copy of it.

In [19]:
n_iter = 0
while n_iter <= max_iter:
    optimizer.step(partial(step, actn_loss))
Iteration: n_iter, loss: 0.8200027942657471
Iteration: n_iter, loss: 0.3576483130455017
Iteration: n_iter, loss: 0.23157010972499847
Iteration: n_iter, loss: 0.17518416047096252
Iteration: n_iter, loss: 0.14312393963336945
Iteration: n_iter, loss: 0.1230238527059555
Iteration: n_iter, loss: 0.10892671346664429
Iteration: n_iter, loss: 0.09870683401823044
Iteration: n_iter, loss: 0.09066757559776306
Iteration: n_iter, loss: 0.08464114367961884

So you can see the loss going down. It is the mean squared error between the layer-37 activations of our VGG model for the optimized image and the target activations; remember, the target activations are the VGG applied to our bird.
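Written out, and assuming `F.mse_loss`'s default mean reduction, the loss that `actn_loss` computes is the following, where $\phi$ is the truncated VGG (the first 37 layers), $x$ is the image being optimized, $y$ is the bird, and $N = 512 \times 18 \times 18$ is the number of target activations. The factor of 1000 is the constant scale from the code.

$$\mathcal{L}_{\text{content}}(x) = \frac{1000}{N} \sum_{i=1}^{N} \big(\phi(x)_i - \phi(y)_i\big)^2$$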

In [20]:
x = val_tfms.denorm(np.rollaxis(to_np(opt_img_v.data), 1, 4))[0]
plt.figure(figsize=(7, 7))
plt.imshow(x)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
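The display cell undoes the transform in reverse order: channels back to the last axis, then denormalization. Here is a minimal NumPy sketch of that, assuming the ImageNet mean/std normalization (so denormalizing is `x * std + mean`). `to_displayable` is a hypothetical helper, and its explicit clip does by hand what imshow's warning says it did.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_displayable(batch_nchw):
    """Hypothetical helper: normalized NCHW batch -> first image as HxWx3 in [0, 1]."""
    nhwc = np.rollaxis(batch_nchw, 1, 4)          # (N, C, H, W) -> (N, H, W, C)
    img = nhwc[0] * IMAGENET_STD + IMAGENET_MEAN  # undo (x - mean) / std
    return np.clip(img, 0, 1)                     # keep values in imshow's float range

batch = np.random.randn(1, 3, 288, 288).astype(np.float32)
img = to_displayable(batch)
print(img.shape)  # (288, 288, 3)
```

`np.moveaxis(batch, 1, -1)` is the equivalent call that NumPy's documentation recommends over `rollaxis`. The clipping warning appears because nothing in the optimization keeps denormalized pixels inside [0, 1].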