#| default_exp activations
#|export
from __future__ import annotations
import random,math,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from torch import tensor,nn,optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset
from miniai.datasets import *
from miniai.learner import *
from fastcore.test import test_close
# Compact tensor printing: 2 decimal places, 140-character lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Turn on matplotlib's constrained layout for figures by default.
mpl.rcParams['figure.constrained_layout.use'] = True
import logging
# Disable log records of severity WARNING and below.
logging.disable(logging.WARNING)
# Request deterministic algorithm implementations from PyTorch.
torch.use_deterministic_algorithms(True)
#|export
def set_seed(seed):
    """Seed PyTorch, Python's `random` module and NumPy with the same `seed`."""
    # Order is kept as torch -> random -> numpy.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
# Names of the image and label fields used to index a batch.
x,y = 'image','label'
# Dataset identifier handed to `load_dataset`.
name = "fashion_mnist"
dsd = load_dataset(name)
# Batch size later passed to `DataLoaders.from_dd`.
bs = 1024
@inplace
def transformi(b):
    """Replace the images stored under `b[x]` with their tensor versions.

    Mutates `b`; the `inplace` decorator is presumably what hands the batch back
    to the caller — confirm against its definition in miniai.
    """
    b[x] = list(map(TF.to_tensor, b[x]))
# Attach `transformi` to the dataset through `with_transform`.
tds = dsd.with_transform(transformi)
# Build the data loaders from the transformed dataset using batch size `bs`.
dls = DataLoaders.from_dd(tds, bs)
# Shortcut to the training loader.
dt = dls.train
0%| | 0/2 [00:00<?, ?it/s]
def conv(ni, nf, ks=3, act=True):
    """Build a stride-2 `nn.Conv2d`, optionally followed by a ReLU.

    ni:  number of input channels
    nf:  number of output channels (filters)
    ks:  kernel size; padding is `ks//2`
    act: when truthy, return `nn.Sequential(conv, nn.ReLU())`, otherwise the bare conv
    """
    layer = nn.Conv2d(ni, nf, kernel_size=ks, stride=2, padding=ks//2)
    return nn.Sequential(layer, nn.ReLU()) if act else layer
By default, we use a 3×3-pixel kernel. If we use eight filters, then the first layer uses nine pixels to calculate eight numbers, which means it isn't really learning much at all. But if we use fewer filters, then there's not much being done in this first layer.
To fix this, we can use a larger kernel in the first layer.
def cnn_layers():
    """Return the layers of the small CNN: five `conv` blocks followed by `nn.Flatten`.

    The trailing size comments give the spatial output of each block.
    """
    # (input channels, output channels, extra keyword arguments for `conv`)
    specs = [
        (1,  8,  {'ks': 5}),        #14x14
        (8,  16, {}),               #7x7
        (16, 32, {}),               #4x4
        (32, 64, {}),               #2x2
        (64, 10, {'act': False}),   #1x1
    ]
    # Built in order, so parameters are initialised in the same sequence as listed.
    layers = [conv(ni, nf, **kw) for ni, nf, kw in specs]
    layers.append(nn.Flatten())
    return layers
We want to train quickly, so that means training at a high learning rate.
# Callback instances shared by every Learner that `fit` builds.
cbs = [TrainCB(), CudaCB(), MetricsCB(Accuracy()), ProgressCB(plot=True)]

def fit(model, epochs=1):
    """Train `model` on `dls` for `epochs` epochs and return the Learner that did it.

    Uses cross-entropy loss, a learning rate of 0.6 and the module-level `cbs`.
    """
    learner_kwargs = dict(loss_func=F.cross_entropy, lr=0.6, cbs=cbs)
    trainer = Learner(model, dls, **learner_kwargs)
    trainer.fit(epochs)
    return trainer
# Seed the RNGs before the model is built and trained.
set_seed(1)
# Baseline: train the plain nn.Sequential CNN for the default single epoch.
learn = fit(nn.Sequential(*cnn_layers()))
2.199 :: 0.147
class SequentialModel(nn.Module):
    """Apply `layers` in order while recording per-layer activation statistics.

    After each layer runs, the mean and standard deviation of its output are
    appended to `act_means[i]` and `act_stds[i]`, where `i` is the layer's
    position. Each list therefore grows by one entry per forward pass.
    """
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        # One independent list per layer.
        self.act_means = [[] for _ in layers]
        self.act_stds  = [[] for _ in layers]

    def forward(self, x):
        """Run `x` through every layer, logging stats on the way, and return the result.

        Implemented as `forward` rather than by overriding `__call__`, so that
        `nn.Module.__call__` still runs and hooks registered on this model fire.
        """
        for i,l in enumerate(self.layers):
            x = l(x)
            # Detach and move to the CPU once, then derive both statistics from that copy.
            acts = x.detach().cpu()
            self.act_means[i].append(acts.mean())
            self.act_stds [i].append(acts.std())
        return x

    def __iter__(self):
        """Iterate over the wrapped layers."""
        return iter(self.layers)
# Same seed as the baseline run.
set_seed(1)
# Same layers, wrapped so that activation statistics are recorded while training.
model = SequentialModel(*cnn_layers())
learn = fit(model)
2.199 :: 0.147
# One line per layer: the activation mean recorded at each forward pass.
for l in model.act_means: plt.plot(l)
# NOTE(review): `cnn_layers()` returns six layers (the last is nn.Flatten), so six
# lines are drawn but only five get a legend entry — confirm whether range(6) was meant.
plt.legend(range(5));
# Same plot for the activation standard deviations.
for l in model.act_stds: plt.plot(l)
plt.legend(range(5));
Hooks are PyTorch objects you can add to any nn.Module. A hook is called when the layer it is registered on is executed during the forward pass (forward hook) or the backward pass (backward hook). Hooks don't require us to rewrite the model.
# Fresh seed and a fresh, unmodified nn.Sequential model to attach hooks to.
set_seed(1)
model = nn.Sequential(*cnn_layers())
A hook is attached to a layer, and needs to have a function that takes three arguments: module, input, output. Here we store the mean and std of the output in the correct position of our list.
# One list of recorded values per layer of `model`.
act_means = [[] for _ in model]
act_stds  = [[] for _ in model]

def append_stats(i, mod, inp, outp):
    """Forward-hook body: log the mean and std of `outp` under layer index `i`.

    `mod` and `inp` are part of the hook signature but are not used.
    """
    acts = outp.data.cpu()
    act_means[i].append(acts.mean())
    act_stds[i].append(acts.std())

# Pre-bind each layer's position; PyTorch supplies (module, input, output) on every call.
# The returned handles are discarded, so these hooks stay attached to this model.
for i,m in enumerate(model):
    m.register_forward_hook(partial(append_stats, i))

fit(model)
2.199 :: 0.147
<miniai.learner.Learner>
# One line per hooked layer: the activation mean recorded at each forward pass.
for o in act_means: plt.plot(o)
# NOTE(review): `act_means` has one list per layer of `model` (six including nn.Flatten),
# but only five legend labels are supplied — confirm whether range(6) was meant.
plt.legend(range(5));
We can refactor this into a Hook class. It's very important to remove hooks once they are no longer needed; otherwise references to them are kept and the memory won't be properly released when your model is deleted.
#| export
class Hook():
    """Register `f` as a forward hook on module `m`.

    `f` is called as `f(hook, module, input, output)`: this object is bound as
    the first argument so the function can store state on it.
    """
    def __init__(self, m, f):
        bound = partial(f, self)
        # Keep the handle so the hook can be detached later.
        self.hook = m.register_forward_hook(bound)

    def remove(self):
        """Detach the hook from its module."""
        self.hook.remove()

    def __del__(self):
        # Detach when this object is destroyed.
        self.remove()
def append_stats(hook, mod, inp, outp):
    """Hook function that logs the mean and std of `outp` on the `hook` object.

    `hook.stats` is created on first use as a pair of lists, `(means, stds)`.
    `mod` and `inp` are accepted to match the hook signature but are not used.
    """
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [])
    activations = outp.data.cpu()
    # Slot 0 collects means, slot 1 collects standard deviations.
    for slot, reducer in enumerate((activations.mean, activations.std)):
        hook.stats[slot].append(reducer())
set_seed(1)
model = nn.Sequential(*cnn_layers())
# Hook only the first five children of the model, leaving out the final nn.Flatten.
hooks = [Hook(l, append_stats) for l in model[:5].children()]
learn = fit(model)
2.199 :: 0.147
# Draw each hooked layer's recorded means (slot 0 of `stats`), then detach
# that hook from its module.
for h in hooks:
    plt.plot(h.stats[0])
    h.remove()

plt.legend(range(5));
#| export
class Hooks(list):
    """A list of `Hook` objects, one per module in `ms`, all calling the function `f`.

    Also usable as a context manager: leaving the `with` block removes every hook.
    """
    def __init__(self, ms, f):
        super().__init__([Hook(m, f) for m in ms])

    def __enter__(self, *args):
        return self

    def __exit__ (self, *args):
        self.remove()

    def __del__(self):
        self.remove()

    def __delitem__(self, i):
        """Remove the hook(s) at index or slice `i`, then drop them from the list."""
        # For an int `self[i]` is a single Hook; for a slice it is a plain list of
        # Hooks, which must be detached one by one.
        doomed = self[i] if isinstance(i, slice) else [self[i]]
        for h in doomed:
            h.remove()
        super().__delitem__(i)

    def remove(self):
        """Detach every hook in the list from its module (the list itself is kept)."""
        for h in self:
            h.remove()
set_seed(1)
model = nn.Sequential(*cnn_layers())

# Hooks are attached for the duration of the `with` block; `__exit__` removes them.
with Hooks(model, append_stats) as hooks:
    fit(model)
    fig,axs = plt.subplots(1,2, figsize=(10,4))
    # Left axis: means (stats[0]); right axis: standard deviations (stats[1]).
    for h in hooks:
        for i in range(2):
            axs[i].plot(h.stats[i])
    plt.legend(range(6));
2.199 :: 0.147
def append_stats(hook, mod, inp, outp):
    """Hook function that logs mean, std and a histogram of `outp` on `hook`.

    `hook.stats` is created on first use as three lists: `(means, stds, histograms)`.
    Each histogram has 40 bins over the range 0 to 10.
    `mod` and `inp` are accepted to match the hook signature but are not used.
    """
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [], [])
    activations = outp.data.cpu()
    # Slots, in order: mean, standard deviation, histogram.
    reducers = (
        activations.mean,
        activations.std,
        lambda: activations.histc(40, 0, 10),
    )
    for slot, reducer in enumerate(reducers):
        hook.stats[slot].append(reducer())
set_seed(1)
model = nn.Sequential(*cnn_layers())
# Train with the histogram-recording hooks attached; they are removed on leaving the block.
with Hooks(model, append_stats) as hooks: fit(model)
2.199 :: 0.147
# Thanks to @ste for the initial version of the histogram plotting code
def get_hist(h):
    """Stack the histograms in `h.stats[2]` and return them transposed, as float, log1p-scaled.

    Stacking puts one recorded histogram per row; the transpose puts one per column.
    """
    stacked = torch.stack(h.stats[2])
    as_columns = stacked.t().float()
    return as_columns.log1p()
# 2x2 grid: one histogram image per hook, for the first four hooks.
fig,axes = subplots(2,2, figsize=(9,5))
for ax,h in zip(axes.flatten(), hooks[:4]):
    # origin='lower' is passed through to `show_image`.
    show_image(get_hist(h), ax, origin='lower')
def get_min(h):
    """Return, per recorded histogram, the share of counts in its first two bins.

    Histograms are read from `h.stats[2]`; the result has one value per histogram.
    """
    per_column = torch.stack(h.stats[2]).t().float()
    lowest_two = per_column[:2].sum(0)
    total = per_column.sum(0)
    return lowest_two / total
# 2x2 grid: for each of the first four hooks, the share of activations in the
# two lowest histogram bins over training.
fig,axes = subplots(2,2, figsize=(15,6))
for ax,h in zip(axes.flatten(), hooks[:4]):
    ax.plot(get_min(h))
    # The plotted value is a ratio, so pin the y-axis to 0..1.
    ax.set_ylim(0,1)
# Run nbdev's export step (presumably writes the `#|export` cells to the `activations` module — confirm).
import nbdev; nbdev.nbdev_export()