#|default_exp augment
#|export
import torch,random
import fastcore.all as fc
from torch import nn
from torch.nn import init
from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
import pickle,gzip,math,os,time,shutil
import matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,optim
from torch.utils.data import DataLoader,default_collate
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
from fastcore.test import test_close
from torch import distributions
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
import logging
logging.disable(logging.WARNING)
set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 1024
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)-xmean)/xstd for o in b[xl]]
dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
metrics = MetricsCB(accuracy=MulticlassAccuracy())
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), metrics, ProgressCB(plot=True), astats]
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)
iw = partial(init_weights, leaky=0.1)
set_seed(42)
lr,epochs = 6e-2,5
def get_model(act=nn.ReLU, nfs=(16,32,64,128,256,512), norm=nn.BatchNorm2d):
layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
layers += [nn.Flatten(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
return nn.Sequential(*layers)
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.823 | 0.707 | 0 | train |
0.844 | 0.537 | 0 | eval |
0.896 | 0.386 | 1 | train |
0.889 | 0.332 | 1 | eval |
0.918 | 0.270 | 2 | train |
0.914 | 0.280 | 2 | eval |
0.941 | 0.199 | 3 | train |
0.928 | 0.242 | 3 | eval |
0.961 | 0.145 | 4 | train |
0.934 | 0.219 | 4 | eval |
class GlobalAvgPool(nn.Module):
def forward(self, x): return x.mean((-2,-1))
def get_model2(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
layers += [ResBlock(256, 512, act=act, norm=norm), GlobalAvgPool()]
layers += [nn.Linear(512, 10, bias=False), nn.BatchNorm1d(10)]
return nn.Sequential(*layers)
#|export
def _flops(x, h, w):
    # bias/batchnorm params (dim<3): one multiply-accumulate per element
    if x.dim()<3: return x.numel()
    # conv weights (dim==4): the kernel is applied at every output position
    if x.dim()==4: return x.numel()*h*w
@fc.patch
def summary(self:Learner):
res = '|Module|Input|Output|Num params|MFLOPS|\n|--|--|--|--|--|\n'
totp,totf = 0,0
def _f(hook, mod, inp, outp):
nonlocal res,totp,totf
nparms = sum(o.numel() for o in mod.parameters())
totp += nparms
*_,h,w = outp.shape
flops = sum(_flops(o, h, w) for o in mod.parameters())/1e6
totf += flops
res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|{flops:.1f}|\n'
with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, cbs=SingleBatchCB())
print(f"Tot params: {totp}; MFLOPS: {totf:.1f}")
if fc.IN_NOTEBOOK:
from IPython.display import Markdown
return Markdown(res)
else: print(res)
TrainLearner(get_model2(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()
Tot params: 4907588; MFLOPS: 33.0
Module | Input | Output | Num params | MFLOPS |
---|---|---|---|---|
ResBlock | (1024, 1, 28, 28) | (1024, 16, 28, 28) | 6928 | 5.3 |
ResBlock | (1024, 16, 28, 28) | (1024, 32, 14, 14) | 14560 | 2.8 |
ResBlock | (1024, 32, 14, 14) | (1024, 64, 7, 7) | 57792 | 2.8 |
ResBlock | (1024, 64, 7, 7) | (1024, 128, 4, 4) | 230272 | 3.7 |
ResBlock | (1024, 128, 4, 4) | (1024, 256, 2, 2) | 919296 | 3.7 |
ResBlock | (1024, 256, 2, 2) | (1024, 512, 2, 2) | 3673600 | 14.7 |
GlobalAvgPool | (1024, 512, 2, 2) | (1024, 512) | 0 | 0.0 |
Linear | (1024, 512) | (1024, 10) | 5120 | 0.0 |
BatchNorm1d | (1024, 10) | (1024, 10) | 20 | 0.0 |
set_seed(42)
model = get_model2(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.819 | 0.724 | 0 | train |
0.861 | 0.448 | 0 | eval |
0.896 | 0.387 | 1 | train |
0.869 | 0.388 | 1 | eval |
0.918 | 0.274 | 2 | train |
0.913 | 0.269 | 2 | eval |
0.940 | 0.202 | 3 | train |
0.922 | 0.248 | 3 | eval |
0.959 | 0.150 | 4 | train |
0.929 | 0.223 | 4 | eval |
def get_model3(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
layers += [GlobalAvgPool(), nn.Linear(256, 10, bias=False), nn.BatchNorm1d(10)]
return nn.Sequential(*layers)
TrainLearner(get_model3(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()
Tot params: 1231428; MFLOPS: 18.3
Module | Input | Output | Num params | MFLOPS |
---|---|---|---|---|
ResBlock | (1024, 1, 28, 28) | (1024, 16, 28, 28) | 6928 | 5.3 |
ResBlock | (1024, 16, 28, 28) | (1024, 32, 14, 14) | 14560 | 2.8 |
ResBlock | (1024, 32, 14, 14) | (1024, 64, 7, 7) | 57792 | 2.8 |
ResBlock | (1024, 64, 7, 7) | (1024, 128, 4, 4) | 230272 | 3.7 |
ResBlock | (1024, 128, 4, 4) | (1024, 256, 2, 2) | 919296 | 3.7 |
GlobalAvgPool | (1024, 256, 2, 2) | (1024, 256) | 0 | 0.0 |
Linear | (1024, 256) | (1024, 10) | 2560 | 0.0 |
BatchNorm1d | (1024, 10) | (1024, 10) | 20 | 0.0 |
[o.shape for o in get_model3()[0].parameters()]
[torch.Size([16, 1, 5, 5]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16, 16, 5, 5]), torch.Size([16]), torch.Size([16]), torch.Size([16]), torch.Size([16, 1, 1, 1]), torch.Size([16])]
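The first ResBlock contains two full 5×5 convs plus a 1×1 conv on the identity path, all running at 28×28 resolution, which is what makes it so expensive; we can replace it with a single plain conv.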
set_seed(42)
model = get_model3(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.810 | 0.758 | 0 | train |
0.857 | 0.489 | 0 | eval |
0.893 | 0.402 | 1 | train |
0.890 | 0.350 | 1 | eval |
0.917 | 0.282 | 2 | train |
0.913 | 0.283 | 2 | eval |
0.937 | 0.212 | 3 | train |
0.922 | 0.260 | 3 | eval |
0.957 | 0.158 | 4 | train |
0.930 | 0.230 | 4 | eval |
def get_model4(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
layers = [conv(1, 16, ks=5, stride=1, act=act, norm=norm)]
layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
layers += [GlobalAvgPool(), nn.Linear(256, 10, bias=False), nn.BatchNorm1d(10)]
return nn.Sequential(*layers)
[o.shape for o in get_model4()[0].parameters()]
[torch.Size([16, 1, 5, 5]), torch.Size([16]), torch.Size([16]), torch.Size([16])]
TrainLearner(get_model4(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()
Tot params: 1224948; MFLOPS: 13.3
Module | Input | Output | Num params | MFLOPS |
---|---|---|---|---|
Sequential | (1024, 1, 28, 28) | (1024, 16, 28, 28) | 448 | 0.3 |
ResBlock | (1024, 16, 28, 28) | (1024, 32, 14, 14) | 14560 | 2.8 |
ResBlock | (1024, 32, 14, 14) | (1024, 64, 7, 7) | 57792 | 2.8 |
ResBlock | (1024, 64, 7, 7) | (1024, 128, 4, 4) | 230272 | 3.7 |
ResBlock | (1024, 128, 4, 4) | (1024, 256, 2, 2) | 919296 | 3.7 |
GlobalAvgPool | (1024, 256, 2, 2) | (1024, 256) | 0 | 0.0 |
Linear | (1024, 256) | (1024, 10) | 2560 | 0.0 |
BatchNorm1d | (1024, 10) | (1024, 10) | 20 | 0.0 |
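Swapping the ResBlock stem for a single conv cuts the first layer from 5.3 to 0.3 MFLOPS (18.3 down to 13.3 overall) with almost no change in parameter count.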
set_seed(42)
model = get_model4(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.804 | 0.784 | 0 | train |
0.855 | 0.506 | 0 | eval |
0.895 | 0.401 | 1 | train |
0.888 | 0.356 | 1 | eval |
0.916 | 0.283 | 2 | train |
0.912 | 0.277 | 2 | eval |
0.937 | 0.215 | 3 | train |
0.917 | 0.280 | 3 | eval |
0.956 | 0.161 | 4 | train |
0.928 | 0.236 | 4 | eval |
CPU times: user 29min 44s, sys: 1min 38s, total: 31min 22s
Wall time: 1min 4s
After 20 epochs without augmentation the model badly overfits: training accuracy is near perfect while validation accuracy stalls.
{'accuracy': '0.999', 'loss': '0.012', 'epoch': 19, 'train': True}
{'accuracy': '0.924', 'loss': '0.284', 'epoch': 19, 'train': False}
With batchnorm in the model, weight decay doesn't really regularize: batchnorm renormalizes the activations whatever scale the weights have, so shrinking the weights barely changes the function the network computes. So we turn to data augmentation instead.
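A minimal sketch of why (my illustration, not from the notebook): a conv followed by batchnorm in training mode is invariant to rescaling the conv weights, and rescaling is all weight decay does.
x = torch.randn(64, 1, 8, 8)
conv,bn = nn.Conv2d(1, 4, 3, bias=False),nn.BatchNorm2d(4)  # both start in training mode
with torch.no_grad():
    y1 = bn(conv(x))
    conv.weight *= 0.5  # shrink the weights, as weight decay would
    y2 = bn(conv(x))
print((y1-y2).abs().max())  # tiny (residual from bn's eps): the output is effectively unchanged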
from torchvision import transforms
def tfm_batch(b, tfm_x=fc.noop, tfm_y=fc.noop): return tfm_x(b[0]),tfm_y(b[1])
tfms = nn.Sequential(transforms.RandomCrop(28, padding=4),
transforms.RandomHorizontalFlip())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)
model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)
#| export
@fc.patch
@fc.delegates(show_images)
def show_image_batch(self:Learner, max_n=9, cbs=None, **kwargs):
self.fit(1, cbs=[SingleBatchCB()]+fc.L(cbs))
show_images(self.batch[0][:max_n], **kwargs)
learn.show_image_batch(max_n=16, imsize=1.5)
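A padding of 4 shifts these small 28×28 images quite a lot, so for the actual training run we use a gentler padding of 1.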
tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
transforms.RandomHorizontalFlip())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)
set_seed(42)
epochs = 20
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.762 | 0.890 | 0 | train |
0.805 | 0.633 | 0 | eval |
0.857 | 0.602 | 1 | train |
0.854 | 0.499 | 1 | eval |
0.877 | 0.479 | 2 | train |
0.863 | 0.422 | 2 | eval |
0.885 | 0.399 | 3 | train |
0.882 | 0.356 | 3 | eval |
0.896 | 0.332 | 4 | train |
0.892 | 0.329 | 4 | eval |
0.904 | 0.292 | 5 | train |
0.898 | 0.296 | 5 | eval |
0.911 | 0.266 | 6 | train |
0.883 | 0.333 | 6 | eval |
0.919 | 0.237 | 7 | train |
0.897 | 0.303 | 7 | eval |
0.924 | 0.223 | 8 | train |
0.877 | 0.329 | 8 | eval |
0.930 | 0.204 | 9 | train |
0.912 | 0.254 | 9 | eval |
0.936 | 0.188 | 10 | train |
0.910 | 0.266 | 10 | eval |
0.937 | 0.182 | 11 | train |
0.926 | 0.216 | 11 | eval |
0.942 | 0.165 | 12 | train |
0.928 | 0.200 | 12 | eval |
0.949 | 0.149 | 13 | train |
0.923 | 0.218 | 13 | eval |
0.954 | 0.136 | 14 | train |
0.935 | 0.195 | 14 | eval |
0.956 | 0.126 | 15 | train |
0.937 | 0.183 | 15 | eval |
0.962 | 0.111 | 16 | train |
0.939 | 0.180 | 16 | eval |
0.966 | 0.103 | 17 | train |
0.940 | 0.185 | 17 | eval |
0.969 | 0.093 | 18 | train |
0.942 | 0.177 | 18 | eval |
0.972 | 0.089 | 19 | train |
0.943 | 0.173 | 19 | eval |
A custom collation function could let you do per-item transformations.
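For instance, here's a minimal sketch (illustrative only: collate_tfm is a hypothetical helper, and it assumes a dataset yielding (image, label) tuples) that transforms each item independently before stacking, so every image gets its own random augmentation parameters.
def collate_tfm(items, tfm=fc.noop):
    xs = torch.stack([tfm(x) for x,_ in items])  # per-item transform, then stack into a batch
    ys = tensor([y for _,y in items])
    return xs,ys
# e.g. DataLoader(train_ds, batch_size=bs, collate_fn=partial(collate_tfm, tfm=transforms.RandomCrop(28, padding=1)))
By contrast, BatchTransformCB transforms the whole batch at once, which is faster (it works on batched, already-on-device tensors) but applies the same random crop and flip decision to every image in the batch.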
mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)
torch.save(learn.model, mdl_path/'data_aug.pkl')
#| export
class CapturePreds(Callback):
def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[]
def after_batch(self, learn):
self.all_inps. append(to_cpu(learn.batch[0]))
self.all_preds.append(to_cpu(learn.preds))
self.all_targs.append(to_cpu(learn.batch[1]))
def after_fit(self, learn):
self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps])
#| export
@fc.patch
def capture_preds(self: Learner, cbs=None, inps=False):
cp = CapturePreds()
self.fit(1, train=False, cbs=[cp]+fc.L(cbs))
res = cp.all_preds,cp.all_targs
if inps: res = res+(cp.all_inps,)
return res
ap1, at = learn.capture_preds()
accuracy | loss | epoch | train |
---|---|---|---|
0.943 | 0.173 | 0 | eval |
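Now for some simple test-time augmentation (TTA): predict a second time on horizontally flipped images, then average the two sets of predictions.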
ttacb = BatchTransformCB(partial(tfm_batch, tfm_x=TF.hflip), on_val=True)
ap2, at = learn.capture_preds(cbs=[ttacb])
accuracy | loss | epoch | train |
---|---|---|---|
0.943 | 0.173 | 0 | eval |
ap1.shape,ap2.shape,at.shape
(torch.Size([10000, 10]), torch.Size([10000, 10]), torch.Size([10000]))
ap = torch.stack([ap1,ap2]).mean(0).argmax(1)
round((ap==at).float().mean().item(), 3)
0.945
xb,_ = next(iter(dls.train))
xbt = xb[:16]
xm,xs = xbt.mean(),xbt.std()
xbt.min(), xbt.max()
(tensor(-0.80), tensor(2.06))
pct = 0.2
szx = int(pct*xbt.shape[-2])
szy = int(pct*xbt.shape[-1])
stx = int(random.random()*(1-pct)*xbt.shape[-2])
sty = int(random.random()*(1-pct)*xbt.shape[-1])
stx,sty,szx,szy
(19, 1, 5, 5)
init.normal_(xbt[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs);
show_images(xbt, imsize=1.5)
xbt.min(), xbt.max()
(tensor(-3.19), tensor(2.64))
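The gaussian noise isn't limited to the range of the original data, so the min and max have changed; to keep the erased batch looking like the real data, the exported version clamps back to the batch's min and max.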
#|export
def _rand_erase1(x, pct, xm, xs, mn, mx):
    # patch size, and a random top-left corner that keeps the patch inside the image
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx = int(random.random()*(1-pct)*x.shape[-2])
    sty = int(random.random()*(1-pct)*x.shape[-1])
    # fill the patch with noise matching the batch stats, then clamp to the original range
    init.normal_(x[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs)
    x.clamp_(mn, mx)
xb,_ = next(iter(dls.train))
xbt = xb[:16]
_rand_erase1(xbt, 0.2, xbt.mean(), xbt.std(), xbt.min(), xbt.max())
show_images(xbt, imsize=1.5)
xbt.mean(),xbt.std(),xbt.min(), xbt.max()
(tensor(-0.01), tensor(1.02), tensor(-0.80), tensor(2.06))
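This time the mean, standard deviation and range of the batch are essentially unchanged.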
#|export
def rand_erase(x, pct=0.2, max_num=4):
    xm,xs,mn,mx = x.mean(),x.std(),x.min(),x.max()
    num = random.randint(0, max_num)
    for i in range(num): _rand_erase1(x, pct, xm, xs, mn, mx)
    return x
xb,_ = next(iter(dls.train))
xbt = xb[:16]
rand_erase(xbt, 0.2, 4)
show_images(xbt, imsize=1.5)
#|export
class RandErase(nn.Module):
def __init__(self, pct=0.2, max_num=4):
super().__init__()
self.pct,self.max_num = pct,max_num
def forward(self, x): return rand_erase(x, self.pct, self.max_num)
tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
transforms.RandomHorizontalFlip(),
RandErase())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)
model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[DeviceCB(), SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)
epochs = 50
lr = 2e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.764 | 0.875 | 0 | train |
0.813 | 0.622 | 0 | eval |
0.842 | 0.625 | 1 | train |
0.850 | 0.552 | 1 | eval |
0.861 | 0.537 | 2 | train |
0.864 | 0.446 | 2 | eval |
0.869 | 0.468 | 3 | train |
0.855 | 0.461 | 3 | eval |
0.873 | 0.422 | 4 | train |
0.861 | 0.435 | 4 | eval |
0.880 | 0.377 | 5 | train |
0.854 | 0.416 | 5 | eval |
0.884 | 0.350 | 6 | train |
0.859 | 0.427 | 6 | eval |
0.887 | 0.333 | 7 | train |
0.842 | 0.445 | 7 | eval |
0.897 | 0.296 | 8 | train |
0.889 | 0.302 | 8 | eval |
0.894 | 0.297 | 9 | train |
0.861 | 0.391 | 9 | eval |
0.894 | 0.294 | 10 | train |
0.847 | 0.439 | 10 | eval |
0.900 | 0.278 | 11 | train |
0.845 | 0.429 | 11 | eval |
0.898 | 0.281 | 12 | train |
0.897 | 0.285 | 12 | eval |
0.908 | 0.254 | 13 | train |
0.910 | 0.259 | 13 | eval |
0.910 | 0.249 | 14 | train |
0.891 | 0.293 | 14 | eval |
0.912 | 0.238 | 15 | train |
0.911 | 0.245 | 15 | eval |
0.916 | 0.230 | 16 | train |
0.910 | 0.242 | 16 | eval |
0.918 | 0.220 | 17 | train |
0.911 | 0.239 | 17 | eval |
0.921 | 0.217 | 18 | train |
0.883 | 0.315 | 18 | eval |
0.924 | 0.208 | 19 | train |
0.917 | 0.236 | 19 | eval |
0.925 | 0.204 | 20 | train |
0.919 | 0.230 | 20 | eval |
0.928 | 0.196 | 21 | train |
0.908 | 0.257 | 21 | eval |
0.932 | 0.188 | 22 | train |
0.921 | 0.217 | 22 | eval |
0.932 | 0.187 | 23 | train |
0.924 | 0.210 | 23 | eval |
0.931 | 0.186 | 24 | train |
0.911 | 0.243 | 24 | eval |
0.934 | 0.178 | 25 | train |
0.902 | 0.270 | 25 | eval |
0.937 | 0.171 | 26 | train |
0.927 | 0.215 | 26 | eval |
0.938 | 0.169 | 27 | train |
0.901 | 0.261 | 27 | eval |
0.939 | 0.167 | 28 | train |
0.915 | 0.238 | 28 | eval |
0.943 | 0.154 | 29 | train |
0.936 | 0.187 | 29 | eval |
0.944 | 0.151 | 30 | train |
0.937 | 0.183 | 30 | eval |
0.947 | 0.145 | 31 | train |
0.928 | 0.212 | 31 | eval |
0.949 | 0.139 | 32 | train |
0.925 | 0.215 | 32 | eval |
0.950 | 0.136 | 33 | train |
0.926 | 0.209 | 33 | eval |
0.951 | 0.137 | 34 | train |
0.936 | 0.183 | 34 | eval |
0.953 | 0.128 | 35 | train |
0.941 | 0.169 | 35 | eval |
0.956 | 0.121 | 36 | train |
0.941 | 0.177 | 36 | eval |
0.958 | 0.116 | 37 | train |
0.936 | 0.191 | 37 | eval |
0.959 | 0.109 | 38 | train |
0.944 | 0.167 | 38 | eval |
0.961 | 0.109 | 39 | train |
0.944 | 0.164 | 39 | eval |
0.961 | 0.104 | 40 | train |
0.941 | 0.177 | 40 | eval |
0.964 | 0.099 | 41 | train |
0.945 | 0.164 | 41 | eval |
0.967 | 0.093 | 42 | train |
0.945 | 0.165 | 42 | eval |
0.968 | 0.088 | 43 | train |
0.947 | 0.168 | 43 | eval |
0.971 | 0.080 | 44 | train |
0.949 | 0.157 | 44 | eval |
0.970 | 0.084 | 45 | train |
0.949 | 0.162 | 45 | eval |
0.974 | 0.075 | 46 | train |
0.948 | 0.161 | 46 | eval |
0.975 | 0.072 | 47 | train |
0.948 | 0.164 | 47 | eval |
0.976 | 0.068 | 48 | train |
0.949 | 0.162 | 48 | eval |
0.976 | 0.070 | 49 | train |
0.949 | 0.163 | 49 | eval |
xb,_ = next(iter(dls.train))
xbt = xb[:16]
szx = int(pct*xbt.shape[-2])
szy = int(pct*xbt.shape[-1])
stx1 = int(random.random()*(1-pct)*xbt.shape[-2])
sty1 = int(random.random()*(1-pct)*xbt.shape[-1])
stx2 = int(random.random()*(1-pct)*xbt.shape[-2])
sty2 = int(random.random()*(1-pct)*xbt.shape[-1])
stx1,sty1,stx2,sty2,szx,szy
(1, 6, 21, 3, 5, 5)
xbt[:,:,stx1:stx1+szx,sty1:sty1+szy] = xbt[:,:,stx2:stx2+szx,sty2:sty2+szy]
show_images(xbt, imsize=1.5)
#|export
def _rand_copy1(x, pct):
    # patch size, plus random destination (1) and source (2) corners
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx1 = int(random.random()*(1-pct)*x.shape[-2])
    sty1 = int(random.random()*(1-pct)*x.shape[-1])
    stx2 = int(random.random()*(1-pct)*x.shape[-2])
    sty2 = int(random.random()*(1-pct)*x.shape[-1])
    # overwrite the destination patch with a copy of the source patch
    x[:,:,stx1:stx1+szx,sty1:sty1+szy] = x[:,:,stx2:stx2+szx,sty2:sty2+szy]
xb,_ = next(iter(dls.train))
xbt = xb[:16]
_rand_copy1(xbt, 0.2)
show_images(xbt, imsize=1.5)
#|export
def rand_copy(x, pct=0.2, max_num=4):
    num = random.randint(0, max_num)
    for i in range(num): _rand_copy1(x, pct)
    return x
xb,_ = next(iter(dls.train))
xbt = xb[:16]
rand_copy(xbt, 0.2, 4)
show_images(xbt, imsize=1.5)
#|export
class RandCopy(nn.Module):
def __init__(self, pct=0.2, max_num=4):
super().__init__()
self.pct,self.max_num = pct,max_num
def forward(self, x): return rand_copy(x, self.pct, self.max_num)
tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
transforms.RandomHorizontalFlip(),
RandCopy())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)
model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[DeviceCB(), SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)
set_seed(1)
epochs = 25
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.734 | 0.945 | 0 | train |
0.802 | 0.641 | 0 | eval |
0.832 | 0.667 | 1 | train |
0.833 | 0.570 | 1 | eval |
0.846 | 0.577 | 2 | train |
0.855 | 0.459 | 2 | eval |
0.859 | 0.487 | 3 | train |
0.849 | 0.477 | 3 | eval |
0.870 | 0.421 | 4 | train |
0.868 | 0.387 | 4 | eval |
0.880 | 0.369 | 5 | train |
0.820 | 0.487 | 5 | eval |
0.885 | 0.340 | 6 | train |
0.888 | 0.317 | 6 | eval |
0.895 | 0.304 | 7 | train |
0.881 | 0.350 | 7 | eval |
0.903 | 0.281 | 8 | train |
0.894 | 0.305 | 8 | eval |
0.906 | 0.268 | 9 | train |
0.908 | 0.265 | 9 | eval |
0.911 | 0.253 | 10 | train |
0.909 | 0.253 | 10 | eval |
0.916 | 0.234 | 11 | train |
0.913 | 0.238 | 11 | eval |
0.919 | 0.230 | 12 | train |
0.907 | 0.256 | 12 | eval |
0.922 | 0.219 | 13 | train |
0.915 | 0.240 | 13 | eval |
0.929 | 0.200 | 14 | train |
0.913 | 0.244 | 14 | eval |
0.928 | 0.201 | 15 | train |
0.933 | 0.194 | 15 | eval |
0.933 | 0.186 | 16 | train |
0.933 | 0.195 | 16 | eval |
0.938 | 0.177 | 17 | train |
0.931 | 0.197 | 17 | eval |
0.940 | 0.168 | 18 | train |
0.936 | 0.184 | 18 | eval |
0.947 | 0.151 | 19 | train |
0.939 | 0.173 | 19 | eval |
0.949 | 0.144 | 20 | train |
0.940 | 0.169 | 20 | eval |
0.950 | 0.143 | 21 | train |
0.943 | 0.168 | 21 | eval |
0.953 | 0.134 | 22 | train |
0.943 | 0.167 | 22 | eval |
0.953 | 0.132 | 23 | train |
0.945 | 0.164 | 23 | eval |
0.954 | 0.132 | 24 | train |
0.944 | 0.163 | 24 | eval |
model2 = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn2 = TrainLearner(model2, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn2.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.743 | 0.932 | 0 | train |
0.807 | 0.621 | 0 | eval |
0.831 | 0.670 | 1 | train |
0.823 | 0.585 | 1 | eval |
0.854 | 0.558 | 2 | train |
0.855 | 0.480 | 2 | eval |
0.865 | 0.478 | 3 | train |
0.875 | 0.404 | 3 | eval |
0.876 | 0.402 | 4 | train |
0.882 | 0.366 | 4 | eval |
0.877 | 0.377 | 5 | train |
0.884 | 0.350 | 5 | eval |
0.890 | 0.327 | 6 | train |
0.904 | 0.281 | 6 | eval |
0.899 | 0.296 | 7 | train |
0.885 | 0.328 | 7 | eval |
0.903 | 0.281 | 8 | train |
0.886 | 0.333 | 8 | eval |
0.908 | 0.266 | 9 | train |
0.901 | 0.278 | 9 | eval |
0.916 | 0.240 | 10 | train |
0.912 | 0.249 | 10 | eval |
0.920 | 0.227 | 11 | train |
0.909 | 0.257 | 11 | eval |
0.922 | 0.220 | 12 | train |
0.913 | 0.240 | 12 | eval |
0.926 | 0.207 | 13 | train |
0.910 | 0.246 | 13 | eval |
0.928 | 0.203 | 14 | train |
0.917 | 0.233 | 14 | eval |
0.930 | 0.195 | 15 | train |
0.922 | 0.223 | 15 | eval |
0.935 | 0.182 | 16 | train |
0.932 | 0.202 | 16 | eval |
0.937 | 0.176 | 17 | train |
0.936 | 0.188 | 17 | eval |
0.939 | 0.172 | 18 | train |
0.935 | 0.186 | 18 | eval |
0.943 | 0.162 | 19 | train |
0.938 | 0.177 | 19 | eval |
0.946 | 0.150 | 20 | train |
0.939 | 0.177 | 20 | eval |
0.950 | 0.139 | 21 | train |
0.941 | 0.170 | 21 | eval |
0.952 | 0.138 | 22 | train |
0.941 | 0.171 | 22 | eval |
0.954 | 0.134 | 23 | train |
0.942 | 0.169 | 23 | eval |
0.954 | 0.131 | 24 | train |
0.944 | 0.166 | 24 | eval |
mdl_path = Path('models')
torch.save(learn.model, mdl_path/'randcopy1.pkl')
torch.save(learn2.model, mdl_path/'randcopy2.pkl')
cp1 = CapturePreds()
learn.fit(1, train=False, cbs=cp1)
accuracy | loss | epoch | train |
---|---|---|---|
0.944 | 0.163 | 0 | eval |
cp2 = CapturePreds()
learn2.fit(1, train=False, cbs=cp2)
accuracy | loss | epoch | train |
---|---|---|---|
0.944 | 0.166 | 0 | eval |
ap = torch.stack([cp1.all_preds,cp2.all_preds]).mean(0).argmax(1)
round((ap==cp1.all_targs).float().mean().item(), 3)
0.947
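Ensembling the two independently trained models lifts accuracy from 0.944 to 0.947, better than either model alone.
Next, dropout. A dropout mask is just a draw from a binomial (Bernoulli) distribution with keep-probability 1-p: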
p = 0.1
dist = distributions.binomial.Binomial(probs=1-p)
dist.sample((10,))
tensor([1., 0., 1., 1., 1., 1., 1., 1., 1., 1.])
class Dropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training: return x
        # sample a 0/1 keep-mask on x's device, and scale by 1/(1-p)
        # so the expected activation magnitude is unchanged
        dist = distributions.binomial.Binomial(tensor(1.0).to(x.device), probs=1-self.p)
        return x * dist.sample(x.size()) * 1/(1-self.p)
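A quick sanity check of that scaling (my addition, not in the original notebook): in training mode roughly p of the activations are zeroed but the mean is preserved, and in eval mode dropout is a no-op.
x = torch.ones(10000)
do = Dropout(0.1)  # modules start in training mode
print(do(x).mean())  # ~1.0: about 10% zeroed, the rest scaled by 1/0.9
do.eval()
print(do(x).mean())  # exactly 1.0: identity at eval time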
def get_dropmodel(act=nn.ReLU, nfs=(16,32,64,128,256,512), norm=nn.BatchNorm2d, drop=0.0):
    # nn.Dropout2d zeroes entire feature maps, which suits spatially correlated conv activations
    layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm), nn.Dropout2d(drop)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [nn.Flatten(), Dropout(drop), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)
set_seed(42)
epochs=5
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_dropmodel(act_gr, norm=nn.BatchNorm2d, drop=0.1).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.809 | 0.746 | 0 | train |
0.829 | 0.557 | 0 | eval |
0.892 | 0.396 | 1 | train |
0.882 | 0.350 | 1 | eval |
0.916 | 0.280 | 2 | train |
0.911 | 0.284 | 2 | eval |
0.937 | 0.209 | 3 | train |
0.924 | 0.242 | 3 | eval |
0.956 | 0.157 | 4 | train |
0.932 | 0.223 | 4 | eval |
class TTD_CB(Callback):
    def before_epoch(self, learn):
        # test-time dropout: switch just the dropout layers back to training mode,
        # so they stay stochastic even during evaluation
        learn.model.apply(lambda m: m.train() if isinstance(m, (nn.Dropout,nn.Dropout2d)) else None)
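A sketch of how this could be used (my assumption of the intent; the notebook doesn't run it): make several stochastic eval passes with dropout live and average them, Monte Carlo dropout style; the spread across passes also gives a rough uncertainty signal.
preds = torch.stack([learn.capture_preds(cbs=[TTD_CB()])[0] for _ in range(5)])
avg,unc = preds.mean(0),preds.std(0)  # averaged predictions, and disagreement across passes
Finally, let's retrain with the inputs scaled to the range (-1,1) instead of normalized with the mean and standard deviation, and save that model too.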
@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)*2-1) for o in b[xl]]
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
set_seed(42)
epochs = 20
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.762 | 0.885 | 0 | train |
0.826 | 0.592 | 0 | eval |
0.853 | 0.608 | 1 | train |
0.851 | 0.511 | 1 | eval |
0.869 | 0.494 | 2 | train |
0.867 | 0.409 | 2 | eval |
0.883 | 0.402 | 3 | train |
0.885 | 0.355 | 3 | eval |
0.894 | 0.337 | 4 | train |
0.843 | 0.457 | 4 | eval |
0.901 | 0.300 | 5 | train |
0.859 | 0.402 | 5 | eval |
0.913 | 0.260 | 6 | train |
0.901 | 0.292 | 6 | eval |
0.918 | 0.240 | 7 | train |
0.891 | 0.320 | 7 | eval |
0.923 | 0.224 | 8 | train |
0.904 | 0.270 | 8 | eval |
0.929 | 0.204 | 9 | train |
0.907 | 0.265 | 9 | eval |
0.935 | 0.189 | 10 | train |
0.906 | 0.269 | 10 | eval |
0.937 | 0.184 | 11 | train |
0.923 | 0.223 | 11 | eval |
0.942 | 0.167 | 12 | train |
0.919 | 0.225 | 12 | eval |
0.947 | 0.153 | 13 | train |
0.928 | 0.209 | 13 | eval |
0.953 | 0.136 | 14 | train |
0.934 | 0.196 | 14 | eval |
0.957 | 0.125 | 15 | train |
0.936 | 0.186 | 15 | eval |
0.961 | 0.114 | 16 | train |
0.938 | 0.179 | 16 | eval |
0.965 | 0.104 | 17 | train |
0.939 | 0.185 | 17 | eval |
0.968 | 0.096 | 18 | train |
0.941 | 0.180 | 18 | eval |
0.971 | 0.091 | 19 | train |
0.941 | 0.177 | 19 | eval |
torch.save(learn.model, 'models/data_aug2.pkl')
import nbdev; nbdev.nbdev_export()