Super Resolution

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

Data

Initial setup

Import libraries:

In [2]:
from fastai.conv_learner import *
from pathlib import Path
# torch.cuda.set_device(0)

torch.backends.cudnn.benchmark = True

Define directory and file paths:

In [3]:
PATH = Path('data/imagenet')
PATH_TRN = PATH / 'train'

We don't really have labels per se, so I'm just going to give everything a label of zero so we can use it with our existing infrastructure more easily.

In [4]:
fnames_full, label_arr_full, all_labels = folder_source(PATH, 'train')
fnames_full = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full]
list(zip(fnames_full[:5], label_arr_full[:5]))
Out[4]:
[('n01440764/n01440764_12241.JPEG', 0),
 ('n01440764/n01440764_529.JPEG', 0),
 ('n01440764/n01440764_11155.JPEG', 0),
 ('n01440764/n01440764_9649.JPEG', 0),
 ('n01440764/n01440764_8013.JPEG', 0)]
In [5]:
all_labels[:5]
Out[5]:
['n01440764', 'n01443537', 'n01491361', 'n01494475', 'n01498041']

Now, because I'm pointing at a folder that contains all of ImageNet, I certainly don't want to wait for all of ImageNet to finish to run an epoch. So most of the time I would set "keep percent" (keep_pct) to 1 or 2%. I generate a bunch of random numbers, one per file, and keep only the files whose number is less than keep_pct, which lets me quickly subsample my rows.

In [6]:
np.random.seed(42)
keep_pct = 1.
# keep_pct = 0.02
keeps = np.random.rand(len(fnames_full)) < keep_pct
fnames = np.array(fnames_full, copy=False)[keeps]
label_arr = np.array(label_arr_full, copy=False)[keeps]

Network architecture

Backbone

We are going to use VGG today even though it's ancient and missing lots of great stuff. One thing we are going to do, though, is use a slightly more modern version: a variant of VGG where batch norm has been added after all the convolutions. In fast.ai, when you ask for a VGG network, you always get the batch norm one, because that's basically always what you want.
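
As a quick sanity check, here is a small sketch that inspects torchvision's batch-norm VGG directly (assuming that's the variant sitting underneath fastai's vgg16); you can see the BatchNorm2d layers interleaved after the convolutions:

# Minimal sketch: look at the first few layers of torchvision's batch-norm VGG.
from torchvision.models import vgg16_bn
print([type(l).__name__ for l in list(vgg16_bn().features.children())[:6]])
# e.g. ['Conv2d', 'BatchNorm2d', 'ReLU', 'Conv2d', 'BatchNorm2d', 'ReLU']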

In [7]:
arch = vgg16

# We are going to go from 72 by 72 low resolution
sz_lr = 72 # size low resolution

Stage 1

We are initially going to scale up by a factor of 2 with a batch size of 64, so the output will be 2 * 72 = 144 by 144. That is going to be our stage one.

In [8]:
scale, bs = 2, 64
# scale, bs = 4, 32
sz_hr = sz_lr * scale

Create our own dataset

I want a dataset where my x's are images and my y's are also images. There's already a FilesDataset inside the fastai.dataset module where the x's are images, so I just inherit from that, copy and paste get_x, and turn it into get_y so that it also just opens an image. What we're passing in is an array of file names.

In [9]:
class MatchedFilesDataset(FilesDataset):
    def __init__(self, fnames, y, transform, path):
        self.y = y
        assert(len(fnames) == len(y))
        super().__init__(fnames, transform, path)
        
    def get_y(self, i): return open_image(os.path.join(self.path, self.y[i]))
    
    def get_c(self): return 0

Data augmentation

RandomDihedral refers to every possible 90 degree rotation plus optional left/right flipping, i.e. the dihedral group of eight symmetries (illustrated in the sketch below). Normally we don't use this transformation for ImageNet pictures because you don't normally flip dogs upside down, but in this case we are not trying to classify whether it's a dog or a cat, we are just trying to keep its general structure. So actually every possible flip is a reasonably sensible thing to do for this problem.

In [10]:
aug_tfms = [RandomDihedral(tfm_y=TfmType.PIXEL)]
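
To make the eight symmetries concrete, here's a tiny illustration in plain NumPy (this is not the fastai transform itself, just the group it samples from):

import numpy as np

img = np.arange(6).reshape(2, 3)
# Four 90-degree rotations, each with and without a left/right flip: 8 variants.
variants = [np.rot90(m, k) for k in range(4) for m in (img, np.fliplr(img))]
len(variants)  # 8
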
In [11]:
val_idxs = get_cv_idxs(len(fnames), val_pct=min(0.01 / keep_pct, 0.1))
((val_x, trn_x), (val_y, trn_y)) = split_by_idx(val_idxs, np.array(fnames), np.array(fnames))
len(val_x), len(trn_x)
Out[11]:
(194, 19245)
In [12]:
img_fn = PATH / 'train' / 'n01558993' / 'n01558993_9684.JPEG'

Create Transformations

We are going to use the tfm_y parameter like we did for bounding boxes, but rather than TfmType.COORD we are going to use TfmType.PIXEL. That tells our transformations framework that the y values are images with normal pixels in them, so anything you do to the x you also need to do to the y. You need to make sure any data augmentation transformations you use get that same parameter as well.

In [13]:
tfms = tfms_from_model(arch, sz_lr, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms, sz_y=sz_hr)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x, trn_y), (val_x, val_y), tfms, path=PATH_TRN)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)

There is a handy little method called get_ds which basically runs that constructor over all the different splits you have, returning all the datasets you need in exactly the right format to pass to a ModelData constructor (in this case the ImageData constructor).
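
Roughly speaking, get_ds does something like the following (a simplified, hypothetical sketch, not the actual fastai source):

# Hypothetical simplification of ImageData.get_ds: apply the dataset class to
# each (x, y) split with its matching transform. The real method also builds
# extra "fix" and "aug" variants of the datasets.
def get_ds_sketch(ds_cls, trn, val, tfms, **kwargs):
    return (ds_cls(trn[0], trn[1], tfms[0], **kwargs),
            ds_cls(val[0], val[1], tfms[1], **kwargs))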

Look at an image from dataset

If we want to be able to display those pictures that have come out of our datasets or data loaders, we need to de-normalize them.

In [14]:
denorm = md.val_ds.denorm

Here's a helper function that shows an image from a dataset; if you pass normed=True to say it's a normalized image, it will denorm it first.

In [15]:
def show_img(ims, idx, figsize=(5,5), normed=True, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    if normed:
        ims = denorm(ims)
    else:
        ims = np.rollaxis(to_np(ims), 1, 4)
    ax.imshow(np.clip(ims, 0, 1)[idx])
    ax.axis('off')
In [16]:
x, y = next(iter(md.val_dl))
x.size(), y.size()
Out[16]:
(torch.Size([32, 3, 72, 72]), torch.Size([32, 3, 288, 288]))

Here you can see the two different resolutions of our x and our y for a bird.

In [18]:
idx = 1
fig, axes = plt.subplots(1, 2, figsize=(9, 5))
show_img(x, idx, ax=axes[0])
show_img(y, idx, ax=axes[1])

Next, let's have a look at a few different versions of the data transformation. There you can see them being flipped in all different directions.

In [19]:
batches = [next(iter(md.aug_dl)) for i in range(9)]
In [20]:
fig, axes = plt.subplots(3, 6, figsize=(18, 9))

for i,(x, y) in enumerate(batches):
    show_img(x, idx, ax=axes.flat[i*2])
    show_img(y, idx, ax=axes.flat[i*2+1])

Model

We are going to have a small image coming in, and we want to have a big image coming out. So we need to do some computation between those two to calculate what the big image would look like.

The way we'll do that computation is to first apply lots of stride-one layers, which do all the work, and then do some upsampling at the end.

We are going to create something with five ResNet blocks and then for each 2x scale up we have to do, we'll have one upsampling block.

In [21]:
def conv(ni, nf, kernel_size=3, actn=False):
    """Standard convolution block"""
    layers = [nn.Conv2d(ni, nf, kernel_size, padding=kernel_size//2)]
    if actn:
        layers.append(nn.ReLU(True))
    return nn.Sequential(*layers)

EDSR and SRResNet idea

One interesting thing about our little conv block is that there is no batch norm, which is pretty unusual for ResNet-type models.

The reason there is no batch norm is that we are borrowing ideas from this fantastic paper, which actually won a recent super-resolution competition.

This paper was a really big step up. They call their model EDSR (Enhanced Deep Super-Resolution network), and they did two things differently from the previous standard approaches:

  1. Take the ResNet blocks and throw away the batch norms.
  2. Add a scaling factor, res_scale.
In [22]:
class ResSequential(nn.Module):
    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)
        
    def forward(self, x):
        return x + self.m(x) * self.res_scale

So we are going to create a residual block containing two convolutions. As you can see in their approach, they don't even have a ReLU after the second conv. That's why I've only got an activation on the first one.

In [23]:
def res_block(nf):
    return ResSequential(
        [conv(nf, nf, actn=True), conv(nf, nf)],
        0.1)

A couple of interesting things here. One is this idea of having some kind of a main ResNet path (conv, ReLU, conv) and then turning it into a ResNet block by adding it back to the identity; it's something we do so often that I factored it out into a tiny little module called ResSequential.

What's res_scale? res_scale is the number 0.1. Why is it there?

Christian Szegedy, who co-invented batch norm, also somewhat more recently did a paper in which he showed (I think for the first time) the ability to train ImageNet in under an hour. Something Christian found was that in the ResNet blocks, if he multiplied the output of the residual branch by some number smaller than 1, something like 0.1 or 0.2, it really helped stabilize training at the start.

In our case, we are just toning down the residual branch relative to our initialization.
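
Here's a tiny numeric sketch of the effect (the tensor shapes and values are arbitrary, purely for illustration): at initialization the residual branch has roughly the same scale as its input, so scaling it by 0.1 keeps x + 0.1 * m(x) much closer to x.

import torch

x = torch.randn(8, 64, 16, 16)
branch = torch.randn_like(x)     # stand-in for m(x) at initialization
print((x + branch).std())        # roughly sqrt(2): noticeably larger than x.std()
print((x + 0.1 * branch).std())  # roughly 1.005: barely perturbed from the identity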

In [24]:
def upsample(ni, nf, scale):
    layers = []
    for i in range(int(math.log(scale, 2))):
        layers += [conv(ni, nf*4), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)

So basically our super-resolution ResNet (SrResnet) first does a convolution to go from our three channels to 64 channels, just to richen up the space a little bit. Then we have eight (not five, as it turns out) res_block layers. Remember, every one of these ResNet blocks is stride 1, so the grid size doesn't change and the number of filters doesn't change; it's just 64 all the way through. We do one more convolution, then our upsampling by however much scale we asked for. Then there's something I've added, a single batch norm, because it felt like it might be helpful to scale the last layer. Finally, a conv to go back to the three channels we want.

In [25]:
class SrResnet(nn.Module):
    def __init__(self, nf, scale):
        super().__init__()
        features = [conv(3, 64)]
        
        for i in range(8):
            features.append(res_block(64))

        features += [
            conv(64, 64),
            upsample(64, 64, scale),
            nn.BatchNorm2d(64),
            conv(64, 3)
        ]
        self.features = nn.Sequential(*features)

    def forward(self, x):
        return self.features(x)

The upsampling is a bit interesting because it is not doing either of the two usual things (transposed/fractionally strided convolutions, or nearest-neighbor upsampling followed by a 1x1 conv), since those can cause checkerboard patterns.

So instead, we are going to use pixel shuffle. Pixel shuffle is the operation at the heart of sub-pixel convolutional neural networks. Refer to the paper Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network to find out more.
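
As a quick shape check (plain PyTorch, with tensor sizes chosen to match our 72x72 low-resolution input): the conv produces 4x the channels, and PixelShuffle(2) rearranges those channels into a grid that is twice as large in each spatial dimension.

import torch
import torch.nn as nn

t = torch.randn(1, 64 * 4, 72, 72)  # like the output of conv(64, 64*4) on a 72x72 feature map
print(nn.PixelShuffle(2)(t).shape)  # torch.Size([1, 64, 144, 144])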

Pixel loss

Initialize the model and send it to the GPU:

In [26]:
m = to_gpu(SrResnet(64, scale))

To make life faster, we are going to run things in parallel.

Take your PyTorch module and wrap it with nn.DataParallel. Once you've done that, it copies the model to each of your GPUs and will automatically run it in parallel. It scales pretty well to two GPUs, okay to three GPUs, better than nothing to four GPUs, and beyond that performance goes backwards. Note that by default it will copy the model to all of your GPUs.

In [ ]:
# Uncomment this line if you have more than 1 GPU.
# m = nn.DataParallel(m, [0, 2])

We create our learner the usual way.

We can use MSE loss here, which just compares the pixels of the output to the pixels we expected.

In [27]:
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)
learn.crit = F.mse_loss
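
Just to make the criterion concrete (random tensors here, purely illustrative): F.mse_loss averages the squared per-pixel difference between the prediction and the high-res target.

import torch
import torch.nn.functional as F

pred = torch.rand(2, 3, 144, 144)
target = torch.rand(2, 3, 144, 144)
print(F.mse_loss(pred, target))       # mean squared error over all pixels
print(((pred - target) ** 2).mean())  # identical by definition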

We can run our learning rate finder and then train for a while.

In [28]:
learn.lr_find(start_lr=1e-5, end_lr=10000)
learn.sched.plot(10, 0)
 30%|███       | 183/602 [03:10<07:17,  1.04s/it, loss=9.88] 
In [29]:
lr = 2e-3
In [30]:
learn.fit(lr, 1, cycle_len=1, use_clr_beta=(40, 10))
epoch      trn_loss   val_loss                               
    0      0.103036   0.09909   
Out[30]:
[array([0.09909])]
In [31]:
x, y = next(iter(md.val_dl))
preds = learn.model(VV(x))

Ground truth image (high-res)

In [44]:
idx = 1
show_img(y, idx, normed=True)