Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
import skimage
PATH = "data/dogscats/"
sz = 224
arch = resnet34
bs = 64
m = arch(True)
m
ResNet( (conv1): Conv2d (3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (maxpool): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1)) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) ) (1): BasicBlock( (conv1): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) ) (2): BasicBlock( (conv1): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d (64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (downsample): Sequential( (0): Conv2d (64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) ) (2): BasicBlock( (conv1): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) ) (3): BasicBlock( (conv1): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d (128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (downsample): Sequential( (0): Conv2d (128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) (2): BasicBlock( (conv1): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) (3): BasicBlock( (conv1): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) (4): BasicBlock( (conv1): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) (5): BasicBlock( (conv1): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d (256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (downsample): Sequential( (0): Conv2d (256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) ) (2): BasicBlock( (conv1): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (relu): ReLU(inplace) (conv2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) ) ) (avgpool): AvgPool2d(kernel_size=7, stride=7, padding=0, ceil_mode=False, count_include_pad=True) (fc): Linear(in_features=512, out_features=1000) )
m = nn.Sequential(*children(m)[:-2],
nn.Conv2d(512, 2, 3, padding=1),
nn.AdaptiveAvgPool2d(1), Flatten(),
nn.LogSoftmax())
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)
learn = ConvLearner.from_model_data(m, data)
learn.freeze_to(-4)
m[-1].trainable
True
m[-4].trainable
True
learn.fit(0.01, 1)
A Jupyter Widget
[ 0. 0.12421 0.1227 0.98633]
learn.fit(0.01, 1, cycle_len=1)
A Jupyter Widget
[ 0. 0.0725 0.07043 0.99072]
class SaveFeatures():
features=None
def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output): self.features = output
def remove(self): self.hook.remove()
x,y = next(iter(data.val_dl))
x,y = x[None,1], y[None,1]
vx = Variable(x.cuda(), requires_grad=True)
dx = data.val_ds.denorm(x)[0]
plt.imshow(dx);
sfs = [SaveFeatures(o) for o in [m[-7], m[-6], m[-5], m[-4]]]
%time py = m(Variable(x.cuda()))
Wall time: 20.1 ms
for o in sfs: o.remove()
[o.features.size() for o in sfs]
[torch.Size([64, 64, 56, 56]), torch.Size([64, 128, 28, 28]), torch.Size([64, 256, 14, 14]), torch.Size([64, 512, 7, 7])]
py = np.exp(to_np(py)[0]); py
array([ 1., 0.], dtype=float32)
feat = np.maximum(0,to_np(sfs[3].features[0]))
feat.shape
(2, 7, 7)
f2=np.dot(np.rollaxis(feat,0,3), py)
f2-=f2.min()
f2/=f2.max()
f2
array([[ 0.06796, 0.09529, 0.04571, 0.02226, 0.01783, 0.02682, 0. ], [ 0.18074, 0.25366, 0.21349, 0.17134, 0.12896, 0.07998, 0.02269], [ 0.32468, 0.53808, 0.58196, 0.51675, 0.36953, 0.20348, 0.06563], [ 0.4753 , 0.82032, 0.93738, 0.84066, 0.58015, 0.30568, 0.09089], [ 0.53621, 0.90491, 1. , 0.87108, 0.60574, 0.33723, 0.11391], [ 0.44609, 0.70038, 0.72927, 0.60194, 0.42627, 0.26258, 0.10247], [ 0.21954, 0.3172 , 0.30497, 0.21987, 0.15255, 0.08741, 0.00709]], dtype=float32)
plt.imshow(dx)
plt.imshow(skimage.transform.resize(f2, dx.shape), alpha=0.5, cmap='hot');
learn.unfreeze()
learn.bn_freeze(True)
# 12 layer groups call for 12 lrs
lr=np.array([[1e-6]*4,[1e-4]*4,[1e-2]*4]).flatten()
learn.fit(lr, 2, cycle_len=1)
A Jupyter Widget
[ 0. 0.02049 0.01577 0.99405] [ 1. 0.01648 0.0137 0.99554]
log_preds,y = learn.TTA()
preds = np.mean(np.exp(log_preds),0)
accuracy_np(preds,y)
0.996
learn.fit(lr, 2, cycle_len=1)
A Jupyter Widget
[ 0. 0.01243 0.01535 0.99554] [ 1. 0.0128 0.01507 0.99554]
log_preds,y = learn.TTA()
preds = np.mean(np.exp(log_preds),0)
accuracy_np(preds,y)
0.99650000000000005