#|default_exp resnet
#|export
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from fastcore.test import test_close
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
import logging
logging.disable(logging.WARNING)
set_seed(42)
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 1024
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)-xmean)/xstd for o in b[xl]]
dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
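A quick sanity check (added here; not in the original notebook): 0.28 and 0.35 are approximately the Fashion-MNIST pixel mean and standard deviation, so a normalized training batch should come out with mean close to 0 and std close to 1.
xb,yb = next(iter(dls.train))
xb.mean(),xb.std()  # expect roughly (0, 1) after normalization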
#|export
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)
metrics = MetricsCB(accuracy=MulticlassAccuracy())
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), metrics, ProgressCB(plot=True), astats]
iw = partial(init_weights, leaky=0.1)
def get_model(act=nn.ReLU, nfs=(8,16,32,64,128), norm=nn.BatchNorm2d):
layers = [conv(1, 8, stride=1, act=act, norm=norm)]
layers += [conv(nfs[i], nfs[i+1], act=act, norm=norm) for i in range(len(nfs)-1)]
return nn.Sequential(*layers, conv(nfs[-1], 10, act=None, norm=norm, bias=True), nn.Flatten()).to(def_device)
set_seed(42)
lr,epochs = 6e-2,5
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
| accuracy | loss | epoch | train |
|---|---|---|---|
| 0.806 | 0.703 | 0 | train |
| 0.847 | 0.456 | 0 | eval |
| 0.884 | 0.333 | 1 | train |
| 0.856 | 0.415 | 1 | eval |
| 0.906 | 0.263 | 2 | train |
| 0.882 | 0.325 | 2 | eval |
| 0.923 | 0.215 | 3 | train |
| 0.910 | 0.251 | 3 | eval |
| 0.940 | 0.170 | 4 | train |
| 0.917 | 0.232 | 4 | eval |
The ResNet (residual network) was introduced in 2015 by Kaiming He et al. in the paper ["Deep Residual Learning for Image Recognition"](https://arxiv.org/abs/1512.03385). The key idea is the skip connection: each block adds its input back onto its output, so the block only has to learn a residual on top of the identity, which allows much deeper networks to train successfully.
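At its simplest, when input and output shapes match, a residual block computes `act(convs(x) + x)`. Here's a minimal sketch of that idea (a hypothetical `MinimalResBlock`, not part of miniai; the `ResBlock` below also handles changes in channels and stride):
class MinimalResBlock(nn.Module):
    def __init__(self, nf):
        super().__init__()
        # two convs that learn the residual on top of the identity
        self.convs = nn.Sequential(nn.Conv2d(nf, nf, kernel_size=3, padding=1), nn.ReLU(),
                                   nn.Conv2d(nf, nf, kernel_size=3, padding=1))
    def forward(self, x): return F.relu(self.convs(x) + x)  # skip connection: add the input back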
#|export
def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),
conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))
class ResBlock(nn.Module):
def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
super().__init__()
self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)
self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)
self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
self.act = act()
def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))
Post-lesson update: Piotr Czapla noticed that we forgot to include `norm=norm` in the call to `_conv_block` above, so the resnets in the lesson didn't have batchnorm in the resblocks! After fixing this, we discovered that initializing the second conv's batchnorm weights to zero makes things worse in every model we tried, so we removed that. That init method was originally introduced to handle training extremely deep models (much deeper than we use here) -- it appears from this little test that it might be worse for less deep models.
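For reference, that trick zero-initializes the weight (gamma) of the last batchnorm in each residual branch, so every block starts out as roughly the identity function. A hedged sketch of what it might look like here (hypothetical `zero_init_last_bn` helper, assuming the `ResBlock` above with `norm=nn.BatchNorm2d`):
def zero_init_last_bn(model):
    # zero the final batchnorm weight in each residual branch so the block
    # initially passes its input straight through the skip connection
    for m in model.modules():
        if isinstance(m, ResBlock):
            bns = [l for l in m.convs.modules() if isinstance(l, nn.BatchNorm2d)]
            if bns: nn.init.zeros_(bns[-1].weight)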
def get_model(act=nn.ReLU, nfs=(8,16,32,64,128,256), norm=nn.BatchNorm2d):
layers = [ResBlock(1, 8, stride=1, act=act, norm=norm)]
layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
layers += [nn.Flatten(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
return nn.Sequential(*layers).to(def_device)
def _print_shape(hook, mod, inp, outp): print(type(mod).__name__, inp[0].shape, outp.shape)
model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, cbs=[DeviceCB(), SingleBatchCB()])
with Hooks(model, _print_shape) as hooks: learn.fit(1, train=False)
ResBlock torch.Size([1024, 1, 28, 28]) torch.Size([1024, 8, 28, 28])
ResBlock torch.Size([1024, 8, 28, 28]) torch.Size([1024, 16, 14, 14])
ResBlock torch.Size([1024, 16, 14, 14]) torch.Size([1024, 32, 7, 7])
ResBlock torch.Size([1024, 32, 7, 7]) torch.Size([1024, 64, 4, 4])
ResBlock torch.Size([1024, 64, 4, 4]) torch.Size([1024, 128, 2, 2])
ResBlock torch.Size([1024, 128, 2, 2]) torch.Size([1024, 256, 1, 1])
Flatten torch.Size([1024, 256, 1, 1]) torch.Size([1024, 256])
Linear torch.Size([1024, 256]) torch.Size([1024, 10])
BatchNorm1d torch.Size([1024, 10]) torch.Size([1024, 10])
@fc.patch
def summary(self:Learner):
res = '|Module|Input|Output|Num params|\n|--|--|--|--|\n'
tot = 0
def _f(hook, mod, inp, outp):
nonlocal res,tot
nparms = sum(o.numel() for o in mod.parameters())
tot += nparms
res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|\n'
with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, train=False, cbs=SingleBatchCB())
print("Tot params: ", tot)
if fc.IN_NOTEBOOK:
from IPython.display import Markdown
return Markdown(res)
else: print(res)
TrainLearner(get_model(), dls, F.cross_entropy, cbs=DeviceCB()).summary()
Tot params: 1228908
| Module | Input | Output | Num params |
|--|--|--|--|
| ResBlock | (1024, 1, 28, 28) | (1024, 8, 28, 28) | 712 |
| ResBlock | (1024, 8, 28, 28) | (1024, 16, 14, 14) | 3696 |
| ResBlock | (1024, 16, 14, 14) | (1024, 32, 7, 7) | 14560 |
| ResBlock | (1024, 32, 7, 7) | (1024, 64, 4, 4) | 57792 |
| ResBlock | (1024, 64, 4, 4) | (1024, 128, 2, 2) | 230272 |
| ResBlock | (1024, 128, 2, 2) | (1024, 256, 1, 1) | 919296 |
| Flatten | (1024, 256, 1, 1) | (1024, 256) | 0 |
| Linear | (1024, 256) | (1024, 10) | 2560 |
| BatchNorm1d | (1024, 10) | (1024, 10) | 20 |
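As a check on the first row: `ResBlock(1, 8)` holds two 3×3 convs plus their batchnorms in the residual branch, and a 1×1 `idconv` on the identity path (each Conv2d here appears to keep its bias term, since the totals only match with biases included). That gives (1·8·9 + 8) + 16 = 96 for the first conv and its batchnorm, (8·8·9 + 8) + 16 = 600 for the second, and 1·8 + 8 = 16 for the idconv: 96 + 600 + 16 = 712, matching the table.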
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
MomentumLearner(model, dls, F.cross_entropy, cbs=DeviceCB()).lr_find()
lr = 2e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
| accuracy | loss | epoch | train |
|---|---|---|---|
| 0.827 | 0.681 | 0 | train |
| 0.814 | 0.609 | 0 | eval |
| 0.894 | 0.352 | 1 | train |
| 0.887 | 0.326 | 1 | eval |
| 0.913 | 0.262 | 2 | train |
| 0.910 | 0.265 | 2 | eval |
| 0.934 | 0.196 | 3 | train |
| 0.922 | 0.227 | 3 | eval |
| 0.950 | 0.149 | 4 | train |
| 0.928 | 0.211 | 4 | eval |
import timm
from timm.models.resnet import BasicBlock, ResNet, Bottleneck
' '.join(timm.list_models('*resnet*'))
'cspresnet50 cspresnet50d cspresnet50w eca_resnet33ts ecaresnet26t ecaresnet50d ecaresnet50d_pruned ecaresnet50t ecaresnet101d ecaresnet101d_pruned ecaresnet200d ecaresnet269d ecaresnetlight ens_adv_inception_resnet_v2 gcresnet33ts gcresnet50t gluon_resnet18_v1b gluon_resnet34_v1b gluon_resnet50_v1b gluon_resnet50_v1c gluon_resnet50_v1d gluon_resnet50_v1s gluon_resnet101_v1b gluon_resnet101_v1c gluon_resnet101_v1d gluon_resnet101_v1s gluon_resnet152_v1b gluon_resnet152_v1c gluon_resnet152_v1d gluon_resnet152_v1s inception_resnet_v2 lambda_resnet26rpt_256 lambda_resnet26t lambda_resnet50ts legacy_seresnet18 legacy_seresnet34 legacy_seresnet50 legacy_seresnet101 legacy_seresnet152 nf_ecaresnet26 nf_ecaresnet50 nf_ecaresnet101 nf_resnet26 nf_resnet50 nf_resnet101 nf_seresnet26 nf_seresnet50 nf_seresnet101 resnet10t resnet14t resnet18 resnet18d resnet26 resnet26d resnet26t resnet32ts resnet33ts resnet34 resnet34d resnet50 resnet50_gn resnet50d resnet50t resnet51q resnet61q resnet101 resnet101d resnet152 resnet152d resnet200 resnet200d resnetaa50 resnetaa50d resnetaa101d resnetblur18 resnetblur50 resnetblur50d resnetblur101d resnetrs50 resnetrs101 resnetrs152 resnetrs200 resnetrs270 resnetrs350 resnetrs420 resnetv2_50 resnetv2_50d resnetv2_50d_evob resnetv2_50d_evos resnetv2_50d_frn resnetv2_50d_gn resnetv2_50t resnetv2_50x1_bit_distilled resnetv2_50x1_bitm resnetv2_50x1_bitm_in21k resnetv2_50x3_bitm resnetv2_50x3_bitm_in21k resnetv2_101 resnetv2_101d resnetv2_101x1_bitm resnetv2_101x1_bitm_in21k resnetv2_101x3_bitm resnetv2_101x3_bitm_in21k resnetv2_152 resnetv2_152d resnetv2_152x2_bit_teacher resnetv2_152x2_bit_teacher_384 resnetv2_152x2_bitm resnetv2_152x2_bitm_in21k resnetv2_152x4_bitm resnetv2_152x4_bitm_in21k seresnet18 seresnet33ts seresnet34 seresnet50 seresnet50t seresnet101 seresnet152 seresnet152d seresnet200d seresnet269d seresnetaa50d skresnet18 skresnet34 skresnet50 skresnet50d ssl_resnet18 ssl_resnet50 swsl_resnet18 swsl_resnet50 tresnet_l tresnet_l_448 tresnet_m tresnet_m_448 tresnet_m_miil_in21k tresnet_v2_l tresnet_xl tresnet_xl_448 tv_resnet34 tv_resnet50 tv_resnet101 tv_resnet152 vit_base_resnet26d_224 vit_base_resnet50_224_in21k vit_base_resnet50_384 vit_base_resnet50d_224 vit_small_resnet26d_224 vit_small_resnet50d_s16_224 wide_resnet50_2 wide_resnet101_2'
- `resnet18`: block=BasicBlock, layers=[2, 2, 2, 2]
- `resnet18d`: block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True
- `resnet10t`: block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True
model = timm.create_model('resnet18d', in_chans=1, num_classes=10)
# model = ResNet(in_chans=1, block=BasicBlock, layers=[2,2,2,2], stem_width=32, avg_down=True)
lr = 2e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
| accuracy | loss | epoch | train |
|---|---|---|---|
| 0.781 | 0.633 | 0 | train |
| 0.664 | 1.316 | 0 | eval |
| 0.878 | 0.329 | 1 | train |
| 0.870 | 0.362 | 1 | eval |
| 0.905 | 0.255 | 2 | train |
| 0.889 | 0.307 | 2 | eval |
| 0.926 | 0.197 | 3 | train |
| 0.911 | 0.244 | 3 | eval |
| 0.945 | 0.150 | 4 | train |
| 0.920 | 0.223 | 4 | eval |
import nbdev; nbdev.nbdev_export()