%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
from pathlib import Path
# Use GPU 0 and let cuDNN benchmark/pick convolution algorithms.
torch.cuda.set_device(0)
torch.backends.cudnn.benchmark=True
PATH = Path('data/imagenet')
PATH_TRN = PATH/'train'
# fastai's folder_source lists the images under PATH/'train' together with their
# integer labels and the class-folder names (see the displayed outputs).
fnames_full,label_arr_full,all_labels = folder_source(PATH, 'train')
# Keep only '<class_dir>/<file>' so names are relative to PATH_TRN.
fnames_full = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full]
# Notebook display: first few (relative file name, label) pairs.
list(zip(fnames_full[:5],label_arr_full[:5]))
[('n01440764/n01440764_9627.JPEG', 0), ('n01440764/n01440764_9609.JPEG', 0), ('n01440764/n01440764_5176.JPEG', 0), ('n01440764/n01440764_6936.JPEG', 0), ('n01440764/n01440764_4005.JPEG', 0)]
# Notebook display: first few class-folder names.
all_labels[:5]
['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475']
np.random.seed(42)
# Fraction of training images to keep (1. = all); lower it for quick experiments.
keep_pct = 1.
# keep_pct = 0.02
keeps = np.random.rand(len(fnames_full)) < keep_pct
# np.asarray rather than np.array(..., copy=False): since NumPy 2.0, copy=False
# raises ValueError whenever a copy is required, which is always true for a list.
fnames = np.asarray(fnames_full)[keeps]
label_arr = np.asarray(label_arr_full)[keeps]
arch = vgg16
# Side length of the low-res input; the high-res target is sz_lr*scale.
sz_lr = 72
scale,bs = 2,64
# scale,bs = 4,32
sz_hr = sz_lr*scale
class MatchedFilesDataset(FilesDataset):
    """Dataset pairing each input file in ``fnames`` with a target file in ``y``.

    The input side is handled by ``FilesDataset``. The target for item ``i`` is
    loaded with ``open_image`` from ``os.path.join(path, y[i])``.

    Args:
        fnames: input file names, relative to ``path``.
        y: target file names, relative to ``path``; same length as ``fnames``.
        transform: fastai transform passed through to ``FilesDataset``.
        path: root directory for both inputs and targets.

    Raises:
        ValueError: if ``fnames`` and ``y`` differ in length.
    """

    def __init__(self, fnames, y, transform, path):
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if len(fnames) != len(y):
            raise ValueError('fnames and y must have the same length '
                             '(got {} and {})'.format(len(fnames), len(y)))
        # Set before the base initialiser runs, as the original did; the base
        # class may query the dataset during init — TODO confirm against fastai.
        self.y = y
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        return open_image(os.path.join(self.path, self.y[i]))

    def get_c(self):
        # Reports 0 classes: the targets here are images, not class labels.
        return 0
# Single augmentation: a random dihedral transform, also applied to the target
# pixels (tfm_y=PIXEL).
aug_tfms = [RandomDihedral(tfm_y=TfmType.PIXEL)]
# 0.01/keep_pct of the kept images is ~1% of the full dataset, capped at 10%
# of the kept subset.
val_idxs = get_cv_idxs(
    len(fnames),
    val_pct=min(0.01 / keep_pct, 0.1),
)
# Inputs and targets are the same files; presumably the transforms produce the
# low-res input and high-res target from one image.
(val_x, trn_x), (val_y, trn_y) = split_by_idx(
    val_idxs, np.array(fnames), np.array(fnames)
)
(len(val_x), len(trn_x))
(12811, 1268356)
# Example image path (not referenced elsewhere in the visible notebook).
img_fn = PATH/'train'/'n01558993'/'n01558993_9684.JPEG'
# Input at size sz_lr; target (tfm_y=PIXEL) at sz_y=sz_hr; aug_tfms applied to both.
tfms = tfms_from_model(arch, sz_lr, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms, sz_y=sz_hr)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH_TRN)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
# Undoes the dataset normalisation; used by the plotting helpers.
denorm = md.val_ds.denorm
def show_img(ims, idx, figsize=(5,5), normed=True, ax=None):
    """Display image ``idx`` of the batch ``ims`` on a matplotlib axis.

    Args:
        ims: batch of images; the non-normed path moves axis 1 to the end, so a
            channel-first (N, C, H, W) layout is assumed.
        idx: index of the image within the batch.
        figsize: size of a new figure, used only when ``ax`` is None.
        normed: if True, undo the dataset normalisation with ``denorm``;
            otherwise only convert to numpy and move channels last.
        ax: axis to draw on; a new figure/axis is created when None.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    if normed:
        ims = denorm(ims)
    else:
        ims = np.rollaxis(to_np(ims), 1, 4)
    # Select the image before clipping so only one image (not the whole batch)
    # is clipped to the [0, 1] range imshow expects for floats.
    ax.imshow(np.clip(ims[idx], 0, 1))
    ax.axis('off')
# One validation batch; display input and target tensor sizes.
x,y = next(iter(md.val_dl))
x.size(),y.size()
(torch.Size([32, 3, 72, 72]), torch.Size([32, 3, 288, 288]))
# Low-res input (left) beside its high-res target (right) for image idx.
idx = 1
fig, axes = plt.subplots(1, 2, figsize=(9, 5))
for panel, batch in zip(axes, (x, y)):
    show_img(batch, idx, ax=panel)
# Nine draws from the augmented loader. Each draw restarts the iterator, so
# (presumably, if aug_dl is unshuffled) image idx recurs under new augmentations.
batches = [next(iter(md.aug_dl)) for _ in range(9)]
fig, axes = plt.subplots(3, 6, figsize=(18, 9))
panels = iter(axes.flat)
for i, (x, y) in enumerate(batches):
    # Consecutive panels: input, then its target.
    show_img(x, idx, ax=next(panels))
    show_img(y, idx, ax=next(panels))
def conv(ni, nf, kernel_size=3, actn=False):
    """Return ``nn.Sequential`` holding a same-padded Conv2d(ni -> nf).

    The convolution is element ``[0]``. When ``actn`` is true, an in-place ReLU
    follows it as element ``[1]``.
    """
    pad = kernel_size // 2  # preserves H and W for odd kernels at stride 1
    modules = [nn.Conv2d(ni, nf, kernel_size, padding=pad)]
    modules += [nn.ReLU(True)] if actn else []
    return nn.Sequential(*modules)
class ResSequential(nn.Module):
    """Residual wrapper computing ``x + res_scale * layers(x)``.

    Args:
        layers: modules applied in order to form the residual branch.
        res_scale: multiplier applied to the branch output before the add.
    """

    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        scaled_branch = self.m(x) * self.res_scale
        return x + scaled_branch
def res_block(nf):
    """Return a residual block: conv+ReLU then conv (3x3, nf -> nf), with the
    branch output scaled by 0.1 before being added to the input."""
    branch = [conv(nf, nf, actn=True), conv(nf, nf)]
    return ResSequential(branch, res_scale=0.1)
def upsample(ni, nf, scale):
    """Build a sub-pixel upsampler that enlarges H and W by ``scale``.

    Each stage is ``conv(., nf*4)`` followed by ``nn.PixelShuffle(2)``, which
    doubles H and W and divides channels by 4. Every stage therefore outputs
    ``nf`` channels, and ``log2(scale)`` stages are stacked.

    Args:
        ni: channels of the input to the first stage.
        nf: channels produced by every stage (and by the whole upsampler).
        scale: total upscaling factor; must be a positive power of two.

    Returns:
        ``nn.Sequential`` of the stages (empty when ``scale == 1``).

    Raises:
        ValueError: if ``scale`` is not a positive power of two.
    """
    # math.log2 is exact for powers of two, unlike math.log(scale, 2).
    n_stages = math.log2(scale) if scale > 0 else -1.0
    if n_stages < 0 or n_stages != int(n_stages):
        raise ValueError('scale must be a positive power of two, got {!r}'.format(scale))
    layers = []
    for _ in range(int(n_stages)):
        layers += [conv(ni, nf*4), nn.PixelShuffle(2)]
        # PixelShuffle(2) turns nf*4 channels into nf for the next stage.
        ni = nf
    return nn.Sequential(*layers)
class SrResnet(nn.Module):
    """Super-resolution ResNet: 3-channel image in, 3-channel image ``scale``x larger out.

    Layout of ``self.features`` (index 10 is addressed directly elsewhere in the
    notebook, e.g. ``features[10][0][0]`` for the PixelShuffle conv):
      0      conv 3 -> nf
      1..8   eight residual blocks (res_block(nf))
      9      conv nf -> nf
      10     upsample(nf, nf, scale)
      11     BatchNorm2d(nf)
      12     conv nf -> 3

    Args:
        nf: number of feature channels used throughout the network.
        scale: upscaling factor passed to ``upsample``.
    """

    def __init__(self, nf, scale):
        super().__init__()
        # Use nf everywhere instead of a hard-coded 64 so the parameter has effect.
        features = [conv(3, nf)]
        features += [res_block(nf) for _ in range(8)]
        features += [conv(nf, nf), upsample(nf, nf, scale),
                     nn.BatchNorm2d(nf),
                     conv(nf, 3)]
        self.features = nn.Sequential(*features)

    def forward(self, x):
        return self.features(x)
# Stage-0 model on the GPU.
m = to_gpu(SrResnet(64, scale))
# if you have more than one GPU, list the GPU ids below
# NOTE: learn.model is this DataParallel wrapper, so its state_dict keys carry a
# 'module.' prefix (relevant when checkpoints are later loaded into a bare model).
m = nn.DataParallel(m, [0])
#m = nn.DataParallel(m, [0,2])
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)
# Stage 0 trains on plain pixel-wise MSE.
learn.crit = F.mse_loss
# Learning-rate range test over a wide range, then plot loss vs. LR.
learn.lr_find(start_lr=1e-5, end_lr=10000)
learn.sched.plot()
A Jupyter Widget
31%|███▏ | 225/720 [00:24<00:53, 9.19it/s, loss=0.0482]
lr=2e-3
# One cycle of one epoch. use_clr_beta=(40,10) selects fastai's cyclical LR/momentum
# schedule — TODO confirm the exact meaning of the tuple against the fastai docs.
learn.fit(lr, 1, cycle_len=1, use_clr_beta=(40,10))
A Jupyter Widget
2%|▏ | 15/720 [00:02<01:52, 6.25it/s, loss=0.042] epoch trn_loss val_loss 0 0.007431 0.008192
[array([0.00819])]
# Predict one validation batch and show target, prediction and input for image idx.
x,y = next(iter(md.val_dl))
preds = learn.model(VV(x))
idx=4
# NOTE(review): target and prediction are shown without de-normalisation but the
# input with it; confirm the targets are not normalised by the transforms.
show_img(y,idx,normed=False)
show_img(preds,idx,normed=False);
show_img(x,idx,normed=True);
# Same comparison on another fresh validation batch, then checkpoint stage 0.
x,y = next(iter(md.val_dl))
preds = learn.model(VV(x))
show_img(y,idx,normed=False)
show_img(preds,idx,normed=False);
show_img(x,idx);
# MSE-pretrained weights; reloaded when building the feature-loss model.
learn.save('sr-samp0')
def icnr(x, scale=2, init=nn.init.kaiming_normal):
    """Return an ICNR initialisation for the weight of a conv feeding PixelShuffle(scale).

    A sub-kernel for ``out_channels // scale**2`` filters is drawn with ``init``.
    Each filter is then repeated ``scale**2`` times in consecutive output
    channels, so every group of channels that PixelShuffle merges starts out
    identical (ICNR, Aitken et al. 2017).

    Args:
        x: conv weight of shape (out_channels, in_channels, kh, kw); only its
            shape is used.
        scale: PixelShuffle upscale factor.
        init: initialiser applied to the sub-kernel; must return the tensor.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: if out_channels is not divisible by ``scale**2``.
    """
    n_rep = scale ** 2
    out_ch, in_ch = x.shape[0], x.shape[1]
    if out_ch % n_rep:
        raise ValueError('out_channels ({}) must be divisible by scale**2 ({})'
                         .format(out_ch, n_rep))
    # Integer division; avoids the float round-trip of int(a / b).
    subkernel = torch.zeros([out_ch // n_rep] + list(x.shape[1:]))
    subkernel = init(subkernel)
    # (in, out/r^2, kh*kw): repeating the flattened kernel along the last dim
    # places r^2 copies of each filter in consecutive output channels after the view.
    subkernel = subkernel.transpose(0, 1)
    subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                            subkernel.shape[1], -1)
    kernel = subkernel.repeat(1, 1, n_rep)
    kernel = kernel.contiguous().view([in_ch, out_ch] + list(x.shape[2:]))
    return kernel.transpose(0, 1)
# Pretrained VGG-16 used as a fixed feature extractor for the perceptual loss.
m_vgg = vgg16(True)
# Index of the layer just before each MaxPool2d (ReLUs, per the displayed output).
blocks = [i-1 for i,o in enumerate(children(m_vgg))
          if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg[i] for i in blocks]
([5, 12, 22, 32, 42], [ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace)])
# Keep VGG up to and including the layer before the third max-pool. Deriving the
# cut from blocks (blocks[2] + 1 == 23 for this network) keeps it consistent with
# the blocks[:3] feature layers used by the loss, instead of a magic 23.
vgg_layers = children(m_vgg)[:blocks[2]+1]
m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
# VGG only supplies fixed features for the loss; do not train it.
set_trainable(m_vgg, False)
def flatten(x):
    """Collapse all dimensions after the first: (N, ...) -> (N, prod(...))."""
    batch = x.size(0)
    return x.view(batch, -1)
class SaveFeatures:
    """Record a module's output on every forward pass via a forward hook.

    ``features`` is ``None`` until the hooked module first runs, then holds its
    most recent output. Call ``remove`` to detach the hook.
    """

    features = None

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_fn)
        self.hook = handle

    def hook_fn(self, module, input, output):
        # Forward-hook signature is (module, input, output); only output is kept.
        self.features = output

    def remove(self):
        handle = self.hook
        handle.remove()
class FeatureLoss(nn.Module):
    """Perceptual loss on ``m``'s features plus a small pixel-space term.

    ``forward`` returns the pixel L1 between ``input`` and ``target`` divided by
    100, plus, for each hooked layer, the L1 between that layer's flattened
    features for ``input`` and for ``target`` times its weight. When
    ``sum_layers`` is true the terms are summed; otherwise the list
    ``[pixel, layer_1, ...]`` is returned. ``zip`` pairs hooks with weights, so
    extra entries in the longer of ``layer_ids``/``layer_wgts`` are ignored.
    """
    def __init__(self, m, layer_ids, layer_wgts):
        super().__init__()
        self.m,self.wgts = m,layer_wgts
        # One forward hook per chosen layer; each keeps that layer's latest output.
        self.sfs = [SaveFeatures(m[i]) for i in layer_ids]
    def forward(self, input, target, sum_layers=True):
        # Forward the target first. VV is a fastai wrapper (presumably a
        # volatile / no-grad Variable — confirm).
        self.m(VV(target.data))
        res = [F.l1_loss(input,target)/100]
        # Snapshot target features now: the next forward pass overwrites the hooks.
        targ_feat = [V(o.features.data.clone()) for o in self.sfs]
        self.m(input)
        res += [F.l1_loss(flatten(inp.features),flatten(targ))*wgt
                for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
        if sum_layers: res = sum(res)
        return res
    def close(self):
        # Detach all forward hooks from m.
        for o in self.sfs: o.remove()
# Fresh model, initialised from the MSE-trained stage-0 checkpoint.
m = SrResnet(64, scale)
m = to_gpu(m)
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)
t = torch.load(learn.get_model_path('sr-samp0'), map_location=lambda storage, loc: storage)
# 'sr-samp0' was saved from an nn.DataParallel-wrapped model, so its keys start
# with 'module.'. Without stripping that prefix, strict=False would silently match
# nothing and leave this bare model untrained.
prefix = 'module.'
t = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in t.items()}
learn.model.load_state_dict(t, strict=False)
# Freeze everything, then make only upsampler, BatchNorm and final conv trainable.
learn.freeze_to(999)
for i in range(10,13): set_trainable(m.features[i], True)
# ICNR-initialise the conv feeding PixelShuffle (done once, after loading, since
# an earlier init would just be overwritten by the checkpoint).
conv_shuffle = m.features[10][0][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);
# if you have more than one GPU, list the GPU ids below
# NOTE(review): device_ids=None makes DataParallel use every visible GPU, unlike
# the [0] used for the stage-0 model; confirm this is intended.
m = nn.DataParallel(m, None)
#m = nn.DataParallel(m, [0,2])
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)
# Appears redundant: the learner was just constructed with md.
learn.set_data(md)
# Feature loss on the first three pre-pool VGG layers, weighted 0.2 / 0.7 / 0.1.
learn.crit = FeatureLoss(m_vgg, blocks[:3], [0.2,0.7,0.1])
lr=6e-3
wd=1e-7
# Linear LR range test from 1e-4 to 0.1 with weight decay wd.
learn.lr_find(1e-4, 0.1, wds=wd, linear=True)
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
1%| | 15/1801 [00:06<12:55, 2.30it/s, loss=0.0965] 12%|█▏ | 220/1801 [01:16<09:08, 2.88it/s, loss=0.42]
# Plot the LR-finder result, skipping its last point.
learn.sched.plot(n_skip_end=1)
# One cycle of two epochs with fastai's use_clr=(20,10) schedule and weight decay wd.
learn.fit(lr, 1, cycle_len=2, wds=wd, use_clr=(20,10))
HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))
epoch trn_loss val_loss 0 0.04523 0.042932 1 0.043574 0.041242
[array([0.04124])]
# Checkpoint the feature-loss model under its own name. Do not save to
# 'sr-samp0': that file holds the MSE-pretrained weights the feature-loss model is
# built from, and overwriting it would make a re-run start from these weights.
learn.save('sr-samp1')
learn.load('sr-samp1')
lr=3e-3
learn.fit(lr, 1, cycle_len=1, wds=wd, use_clr=(20,10))
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
epoch trn_loss val_loss 0 0.069054 0.06638
[array([0.06638])]
learn.save('sr-samp2')
# Make every layer trainable and fine-tune the whole network at a third of the LR.
learn.unfreeze()
learn.load('sr-samp2')
learn.fit(lr/3, 1, cycle_len=1, wds=wd, use_clr=(20,10))
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
epoch trn_loss val_loss 0 0.06042 0.057613
[array([0.05761])]
learn.save('sr1')
# Plot the training loss recorded by the scheduler during the last fit.
learn.sched.plot_loss()
def plot_ds_img(idx, ax=None, figsize=(7,7), normed=True):
    """Display the input image of validation sample ``idx``.

    Args:
        idx: index into ``md.val_ds``.
        ax: axis to draw on; a new figure/axis is created when None.
        figsize: size of a new figure, used only when ``ax`` is None.
        normed: if True, undo the normalisation with ``denorm`` (taking ``[0]``
            of its output); otherwise only move the channel axis last.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    im = md.val_ds[idx][0]
    if normed:
        im = denorm(im)[0]
    else:
        im = np.rollaxis(to_np(im), 0, 3)
    # Clip here, as show_img does, instead of letting imshow clip and emit a
    # "Clipping input data" warning for every image.
    ax.imshow(np.clip(im, 0, 1))
    ax.axis('off')
fig,axes=plt.subplots(6,6,figsize=(20,20))
# Show the last len(axes.flat) validation images. Starting at
# len(val_x) - n_panels avoids indexing past the end of the validation set
# (IndexError: index 200 is out of bounds for axis 0 with size 194). Clamp at 0
# so a validation set smaller than the grid cannot produce negative indices.
offset = max(0, len(val_x)-len(axes.flat))
offset
for i,ax in enumerate(axes.flat):
    if i+offset < len(val_x):
        plot_ds_img(i+offset, ax=ax, normed=True)
    else:
        ax.axis('off')  # leave unused panels blank
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Last validation sample; add a leading batch dimension to the target.
x,y=md.val_ds[len(val_x)-1]
y=y[None]
# Inference mode (e.g. BatchNorm uses running statistics) before predicting.
learn.model.eval()
preds = learn.model(VV(x[None]))
x.shape,y.shape,preds.shape
((3, 72, 72), (1, 3, 288, 288), torch.Size([1, 3, 288, 288]))
# Unsummed loss terms: [pixel L1 / 100, then one weighted term per VGG layer].
learn.crit(preds, V(y), sum_layers=False)
[Variable containing: 1.00000e-03 * 1.1935 [torch.cuda.FloatTensor of size 1 (GPU 0)], Variable containing: 1.00000e-03 * 8.5054 [torch.cuda.FloatTensor of size 1 (GPU 0)], Variable containing: 1.00000e-02 * 3.4656 [torch.cuda.FloatTensor of size 1 (GPU 0)], Variable containing: 1.00000e-03 * 3.8243 [torch.cuda.FloatTensor of size 1 (GPU 0)]]
# Remove the VGG forward hooks; learn.crit no longer records features afterwards.
learn.crit.close()
# Low-res input (left) next to the model's output (right).
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds,0, normed=True, ax=axes[1])