Lesson 7: Super Resolution

In [1]:
from fastai.vision import *
from fastai.callbacks import *

from torchvision.models import vgg16_bn
In [2]:
path = untar_data(URLs.PETS)
path_hr = path / 'images'
path_lr = path / 'small-96'
path_mr = path / 'small-256'
In [3]:
il = ImageItemList.from_folder(path_hr)
In [4]:
def resize_one(fn, i):
    # mirror the folder structure of path_hr under path_lr
    dest = path_lr / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    targ_sz = resize_to(img, 96, use_min=True)  # shorter side becomes 96 px, aspect ratio preserved
    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    img.save(dest, quality=60)  # low JPEG quality to further degrade the low-res input
In [5]:
# To create the small low-resolution images, uncomment the next line the first time you run this notebook
# (the progress bar below is from that initial run).
# parallel(resize_one, il.items)
100.00% [7390/7390 00:33<00:00]
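
path_mr ('small-256') is defined above but never filled in this section. If you also want 256-pixel copies, the same pattern can be reused; a minimal sketch (the function name resize_one_mr and the JPEG quality of 90 are assumptions, not from the original notebook):

def resize_one_mr(fn, i):
    dest = path_mr / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    targ_sz = resize_to(img, 256, use_min=True)  # shorter side becomes 256 px
    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    img.save(dest, quality=90)
# parallel(resize_one_mr, il.items)
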
In [4]:
bs, size = 32, 128
arch = models.resnet34

src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=9)
In [6]:
def get_data(bs, size):
    data = (src.label_from_func(lambda x: path_hr / x.name)
            .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data
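
Here label_from_func pairs each low-resolution image with its high-resolution original, tfm_y=True applies the same transforms to the target, and do_y=True normalizes the targets with the ImageNet stats as well, so input and target stay aligned. data.c = 3 is set by hand because unet_learner reads the number of output channels from data.c, and for an image-to-image task that is the three RGB channels.
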
In [17]:
data = get_data(bs, size)
In [18]:
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

Feature loss

In [23]:
t = data.valid_ds[0][1].data # get torch.Tensor from fastai.vision.image.Image
In [24]:
# Sanity check
t.shape
Out[24]:
torch.Size([3, 128, 128])
In [25]:
t = torch.stack([t, t]) # stack two copies along a new batch dimension -> [2, 3, 128, 128]
In [26]:
# Sanity check
t.shape
Out[26]:
torch.Size([2, 3, 128, 128])
In [27]:
def gram_matrix(x):
    n, c, h, w = x.size() # n = 2, c = 3, h = 128, w = 128
    x = x.view(n, c, -1) # flatten the spatial dims: [n, c, h*w]
    return (x @ x.transpose(1, 2)) / (c * h * w) # [2, 3, 16384] @ [2, 16384, 3]
In [37]:
# Sanity check
t_new = t.view(2, 3, -1)
t_new, t_new.shape
Out[37]:
(tensor([[[0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100],
          [0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100],
          [0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100]],
 
         [[0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100],
          [0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100],
          [0.0611, 0.0543, 0.0630,  ..., 0.1063, 0.1047, 0.1100]]]),
 torch.Size([2, 3, 16384]))
In [41]:
t_new.transpose(1, 2).shape
Out[41]:
torch.Size([2, 16384, 3])
In [28]:
gram_matrix(t)
Out[28]:
tensor([[[0.1121, 0.1037, 0.0960],
         [0.1037, 0.0968, 0.0897],
         [0.0960, 0.0897, 0.0833]],

        [[0.1121, 0.1037, 0.0960],
         [0.1037, 0.0968, 0.0897],
         [0.0960, 0.0897, 0.0833]]])
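
In matrix form, with the feature map flattened to $F \in \mathbb{R}^{C \times HW}$ (here $C = 3$ and $HW = 16384$), gram_matrix computes

$$G = \frac{F F^{\top}}{C\,H\,W}, \qquad G_{ij} = \frac{1}{CHW} \sum_{k=1}^{HW} F_{ik} F_{jk},$$

the scaled dot products between every pair of channels: they capture which features tend to activate together, regardless of where in the image they do so.
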
In [42]:
base_loss = F.l1_loss
In [43]:
vgg_m = vgg16_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)

Find, in the VGG16 feature extractor, the index of the layer just before each MaxPool2d. Each of these is a ReLU, and its activations are the feature maps we will compare in the loss.

In [44]:
blocks = [i - 1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]
In [51]:
blocks, [vgg_m[i] for i in blocks]
Out[51]:
([5, 12, 22, 32, 42],
 [ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace)])
In [53]:
class FeatureLoss(nn.Module):
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])
        
    def make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()
In [55]:
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])
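
blocks[2:5] selects the three deepest pre-pool ReLUs (indices 22, 32 and 42 above), weighted 5, 15 and 2: each hooked layer contributes a weighted L1 loss between the feature maps of prediction and target, plus a Gram-matrix (style) term scaled by w**2 * 5e3, on top of the plain pixel-wise L1. A quick, hypothetical way to look at the individual components, reusing the stacked tensor t from earlier (the 0.1 noise level is arbitrary, and the tensors are not ImageNet-normalized here, so the numbers are only illustrative):

noisy = (t + 0.1 * torch.randn_like(t)).clamp(0, 1)
_ = feat_loss(noisy.cuda(), t.cuda())  # the forward pass also fills feat_loss.metrics
{k: float(v) for k, v in feat_loss.metrics.items()}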

Train

In [56]:
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics,
                     blur=True, norm_type=NormType.Weight)
gc.collect()
Out[56]:
8002
In [57]:
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [58]:
lr = 1e-3
In [59]:
# helper: 10 epochs of one-cycle training, then save the model and show a sample result
def do_fit(save_name, lrs=slice(lr), pct_start=0.9):
    learn.fit_one_cycle(10, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=5)
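
Note the default pct_start=0.9: in fit_one_cycle this is the fraction of iterations spent increasing the learning rate, so the LR stays high for most of training and only anneals over the final 10% of the cycle.
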
In [60]:
do_fit('superres_1a', slice(lr*10))
Total time: 40:39

epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2
1 3.878758 3.765364 0.143471 0.231686 0.318524 0.225756 0.569088 1.224830 1.052009
2 3.746968 3.623530 0.146915 0.228077 0.310260 0.217326 0.531060 1.177929 1.011963
3 3.683856 3.542798 0.144409 0.225192 0.304307 0.212406 0.523345 1.143946 0.989192
4 3.625614 3.495890 0.148949 0.225164 0.301398 0.209922 0.506778 1.128076 0.975601
5 3.575768 3.412124 0.142050 0.220418 0.293528 0.205400 0.488623 1.102801 0.959304
6 3.530985 3.363271 0.144510 0.222042 0.292827 0.200777 0.482796 1.082339 0.937979
7 3.483688 3.300640 0.140738 0.220030 0.288998 0.198624 0.460092 1.063068 0.929091
8 3.479184 3.300606 0.138870 0.219142 0.287727 0.197092 0.472708 1.066907 0.918159
9 3.437450 3.271332 0.144248 0.220011 0.287269 0.196026 0.453093 1.057170 0.913514
10 3.353931 3.148826 0.136380 0.215376 0.278493 0.188945 0.429711 1.018132 0.881789
In [61]:
learn.unfreeze()
In [62]:
do_fit('superres_1b', slice(1e-5, lr))
Total time: 42:01

epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2
1 3.323620 3.144714 0.135964 0.215112 0.278066 0.188430 0.430221 1.017568 0.879353
2 3.299178 3.139331 0.136003 0.215047 0.277739 0.188079 0.428874 1.016128 0.877462
3 3.304949 3.134083 0.135695 0.214730 0.277511 0.187799 0.427655 1.013973 0.876721
4 3.288480 3.132256 0.135890 0.215035 0.277453 0.187660 0.427123 1.013997 0.875097
5 3.303428 3.128728 0.136634 0.214946 0.277462 0.187666 0.424878 1.012343 0.874799
6 3.289925 3.122605 0.136543 0.214064 0.276433 0.187116 0.425160 1.010485 0.872804
7 3.292779 3.113085 0.135635 0.213773 0.276006 0.186685 0.423021 1.007254 0.870709
8 3.276536 3.104526 0.135603 0.213475 0.275559 0.185891 0.421933 1.005491 0.866574
9 3.262141 3.108630 0.136470 0.213704 0.275759 0.186213 0.420140 1.007921 0.868424
10 3.274122 3.088803 0.135310 0.213197 0.274378 0.185061 0.417686 1.000610 0.862560
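
Progressive resizing: the second stage rebuilds the DataBunch at twice the size (256 px) with a smaller batch size of 12, points the existing learner at it, freezes the encoder again, and reloads the weights trained at 128 px before fine-tuning at the larger resolution.
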
In [7]:
data = get_data(12, size*2)
In [64]:
learn.data = data
learn.freeze()
gc.collect()
Out[64]:
18621
In [65]:
learn.load('superres_1b')
Out[65]:
Learner(data=ImageDataBunch;

Train: LabelList
y: ImageItemList (6651 items)
[Image (3, 331, 500), Image (3, 500, 335), Image (3, 375, 500), Image (3, 500, 333), Image (3, 375, 500)]...
Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/small-96
x: ImageImageList (6651 items)
[Image (3, 96, 145), Image (3, 143, 96), Image (3, 96, 128), Image (3, 144, 96), Image (3, 96, 128)]...
Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/small-96;

Valid: LabelList
y: ImageItemList (739 items)
[Image (3, 500, 333), Image (3, 500, 344), Image (3, 375, 500), Image (3, 500, 333), Image (3, 500, 375)]...
Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/small-96
x: ImageImageList (739 items)
[Image (3, 144, 96), Image (3, 139, 96), Image (3, 96, 128), Image (3, 144, 96), Image (3, 128, 96)]...
Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/small-96;

Test: None, model=DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (5): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (6): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (5): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (7): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Sequential(
      (0): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (1): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
    )
    (4): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (shuf): PixelShuffle(upscale_factor=2)
        (pad): ReplicationPad2d((1, 0, 1, 0))
        (blur): AvgPool2d(kernel_size=2, stride=1, padding=0)
        (relu): ReLU(inplace)
      )
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (conv2): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (relu): ReLU()
    )
    (5): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (shuf): PixelShuffle(upscale_factor=2)
        (pad): ReplicationPad2d((1, 0, 1, 0))
        (blur): AvgPool2d(kernel_size=2, stride=1, padding=0)
        (relu): ReLU(inplace)
      )
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (conv2): Sequential(
        (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (relu): ReLU()
    )
    (6): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (conv): Sequential(
          (0): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
        )
        (shuf): PixelShuffle(upscale_factor=2)
        (pad): ReplicationPad2d((1, 0, 1, 0))
        (blur): AvgPool2d(kernel_size=2, stride=1, padding=0)
        (relu): ReLU(inplace)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (conv2): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (relu): ReLU()
    )
    (7): UnetBlock(
      (shuf): PixelShuffle_ICNR(
        (conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (shuf): PixelShuffle(upscale_factor=2)
        (pad): ReplicationPad2d((1, 0, 1, 0))
        (blur): AvgPool2d(kernel_size=2, stride=1, padding=0)
        (relu): ReLU(inplace)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Sequential(
        (0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (conv2): Sequential(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
      )
      (relu): ReLU()
    )
    (8): PixelShuffle_ICNR(
      (conv): Sequential(
        (0): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
      )
      (shuf): PixelShuffle(upscale_factor=2)
      (pad): ReplicationPad2d((1, 0, 1, 0))
      (blur): AvgPool2d(kernel_size=2, stride=1, padding=0)
      (relu): ReLU(inplace)
    )
    (9): MergeLayer()
    (10): SequentialEx(
      (layers): ModuleList(
        (0): Sequential(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
        )
        (1): Sequential(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
        )
        (2): MergeLayer()
      )
    )
    (11): Sequential(
      (0): Conv2d(99, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
), opt_func=functools.partial(<class 'torch.optim.adam.Adam'>, betas=(0.9, 0.99)), loss_func=FeatureLoss(
  (m_feat): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU(inplace)
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace)
    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU(inplace)
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU(inplace)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace)
    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
), metrics=[], true_wd=True, bn_wd=True, wd=0.001, train_bn=True, path=PosixPath('/home/ubuntu/.fastai/data/oxford-iiit-pet/small-96'), model_dir='models', callback_fns=[<class 'fastai.basic_train.Recorder'>, <class 'fastai.callbacks.loss_metrics.LossMetrics'>], callbacks=[], layer_groups=[Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace)
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU(inplace)
  (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (15): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (16): ReLU(inplace)
  (17): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (18): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (19): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (20): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (21): ReLU(inplace)
  (22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (23): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (24): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (25): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (26): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (27): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (28): ReLU(inplace)
  (29): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (30): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (31): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (32): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (33): ReLU(inplace)
  (34): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (35): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (36): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (37): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (38): ReLU(inplace)
  (39): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (40): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Sequential(
  (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReLU(inplace)
  (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (16): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (19): ReLU(inplace)
  (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (22): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (23): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (24): ReLU(inplace)
  (25): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (26): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (27): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (28): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (29): ReLU(inplace)
  (30): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (31): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (32): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (33): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (34): ReLU(inplace)
  (35): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (36): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (37): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (39): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (40): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (41): ReLU(inplace)
  (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (43): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (44): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (45): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (46): ReLU(inplace)
  (47): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (48): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Sequential(
  (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ReLU()
  (2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): ReLU(inplace)
  (6): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
  (7): PixelShuffle(upscale_factor=2)
  (8): ReplicationPad2d((1, 0, 1, 0))
  (9): AvgPool2d(kernel_size=2, stride=1, padding=0)
  (10): ReLU(inplace)
  (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): ReLU()
  (17): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
  (18): PixelShuffle(upscale_factor=2)
  (19): ReplicationPad2d((1, 0, 1, 0))
  (20): AvgPool2d(kernel_size=2, stride=1, padding=0)
  (21): ReLU(inplace)
  (22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (23): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU(inplace)
  (25): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU(inplace)
  (27): ReLU()
  (28): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
  (29): PixelShuffle(upscale_factor=2)
  (30): ReplicationPad2d((1, 0, 1, 0))
  (31): AvgPool2d(kernel_size=2, stride=1, padding=0)
  (32): ReLU(inplace)
  (33): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (34): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU(inplace)
  (36): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (37): ReLU(inplace)
  (38): ReLU()
  (39): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
  (40): PixelShuffle(upscale_factor=2)
  (41): ReplicationPad2d((1, 0, 1, 0))
  (42): AvgPool2d(kernel_size=2, stride=1, padding=0)
  (43): ReLU(inplace)
  (44): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (45): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (46): ReLU(inplace)
  (47): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (48): ReLU(inplace)
  (49): ReLU()
  (50): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
  (51): PixelShuffle(upscale_factor=2)
  (52): ReplicationPad2d((1, 0, 1, 0))
  (53): AvgPool2d(kernel_size=2, stride=1, padding=0)
  (54): ReLU(inplace)
  (55): MergeLayer()
  (56): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (57): ReLU(inplace)
  (58): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (59): ReLU(inplace)
  (60): MergeLayer()
  (61): Conv2d(99, 3, kernel_size=(1, 1), stride=(1, 1))
)])
In [66]:
do_fit('superres_2a')
Total time: 2:52:07

epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2
1 2.223401 2.208069 0.163590 0.258079 0.294132 0.155751 0.390485 0.582918 0.363114
2 2.207594 2.177671 0.164055 0.259362 0.293865 0.154717 0.373807 0.573298 0.358567
3 2.191846 2.158177 0.166473 0.260517 0.293283 0.154534 0.360691 0.566488 0.356189
4 2.164305 2.146150 0.166441 0.258691 0.290845 0.152078 0.360703 0.564724 0.352669
5 2.135329 2.135777 0.167998 0.259458 0.289835 0.152016 0.354082 0.561161 0.351228
6 2.139521 2.116349 0.167046 0.257672 0.287944 0.150504 0.347994 0.555875 0.349314
7 2.102883 2.111770 0.167683 0.257555 0.286768 0.150004 0.348646 0.553397 0.347716
8 2.111377 2.101841 0.165881 0.257276 0.286179 0.149490 0.345602 0.552041 0.345372
9 2.087749 2.092504 0.165618 0.257330 0.285525 0.149661 0.339318 0.548332 0.346720
10 2.070814 2.073214 0.166590 0.257045 0.283973 0.147354 0.333878 0.542669 0.341706
In [67]:
learn.unfreeze()
In [68]:
do_fit('superres_2b', slice(1e-6, 1e-4), pct_start=0.3)
Total time: 2:57:25

epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2
1 2.083004 2.072233 0.166136 0.256832 0.283742 0.147315 0.334218 0.542541 0.341448
2 2.055553 2.069016 0.165861 0.256336 0.283636 0.147269 0.333190 0.541869 0.340853
3 2.065424 2.067941 0.166512 0.256510 0.283271 0.146918 0.333683 0.541213 0.339833
4 2.063064 2.063793 0.165930 0.256200 0.283104 0.146570 0.332284 0.540426 0.339278
5 2.068301 2.062547 0.166182 0.256659 0.283385 0.146777 0.330254 0.539938 0.339351
6 2.058124 2.058460 0.165859 0.255868 0.282628 0.146286 0.331050 0.538475 0.338294
7 2.055563 2.058390 0.166337 0.256425 0.282894 0.146360 0.329987 0.538086 0.338300
8 2.059268 2.055973 0.166258 0.256073 0.282665 0.146101 0.329223 0.537654 0.337999
9 2.046954 2.055591 0.166408 0.256502 0.283080 0.146351 0.327730 0.537261 0.338258
10 2.049621 2.055515 0.166337 0.256031 0.282606 0.146133 0.329076 0.537291 0.338040