#|default_exp conv

#|export
import torch
from torch import nn
from torch.utils.data import default_collate
from typing import Mapping

from miniai.training import *
from miniai.datasets import *

import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np
import pandas as pd,matplotlib.pyplot as plt
from pathlib import Path
from torch import tensor
from torch.utils.data import DataLoader
from typing import Mapping

mpl.rcParams['image.cmap'] = 'gray'

# Load MNIST and reshape the flat 784-pixel rows into 28x28 images.
path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])
x_imgs = x_train.view(-1,28,28)
xv_imgs = x_valid.view(-1,28,28)
mpl.rcParams['figure.dpi'] = 30

im3 = x_imgs[7]
show_image(im3);

# A 3x3 kernel that responds to horizontal (top) edges.
top_edge = tensor([[-1,-1,-1],
                   [ 0, 0, 0],
                   [ 1, 1, 1]]).float()
show_image(top_edge, noframe=False);

# Look at the raw pixel values in part of the image.
df = pd.DataFrame(im3[:13,:23])
df.style.format(precision=2).set_properties(**{'font-size':'7pt'}).background_gradient('Greys')

# Elementwise multiply-and-sum of the kernel with two different 3x3 patches.
(im3[3:6,14:17] * top_edge).sum()
(im3[7:10,14:17] * top_edge).sum()

def apply_kernel(row, col, kernel): return (im3[row-1:row+2,col-1:col+2] * kernel).sum()

apply_kernel(4,15,top_edge)

# A nested list comprehension producing a grid of (i,j) coordinates, as used next.
[[(i,j) for j in range(5)] for i in range(5)]

# Apply the kernel at every valid position: a 28x28 image gives a 26x26 output.
rng = range(1,27)
top_edge3 = tensor([[apply_kernel(i,j,top_edge) for j in rng] for i in rng])
show_image(top_edge3);

# A kernel that responds to vertical (left) edges.
left_edge = tensor([[-1,0,1],
                    [-1,0,1],
                    [-1,0,1]]).float()
show_image(left_edge, noframe=False);

left_edge3 = tensor([[apply_kernel(i,j,left_edge) for j in rng] for i in rng])
show_image(left_edge3);

import torch.nn.functional as F
import torch

# im2col: unfold extracts every 3x3 patch as a column, so the convolution
# becomes a single matrix multiplication.
inp = im3[None,None,:,:].float()
inp_unf = F.unfold(inp, (3,3))[0]
inp_unf.shape

w = left_edge.view(-1)
w.shape

out_unf = w@inp_unf
out_unf.shape

out = out_unf.view(26,26)
show_image(out);

# Compare the Python loop, the unfold+matmul version, and PyTorch's conv2d.
%timeit -n 1 tensor([[apply_kernel(i,j,left_edge) for j in rng] for i in rng]);
%timeit -n 100 (w@F.unfold(inp, (3,3))[0]).view(26,26);
%timeit -n 100 F.conv2d(inp, left_edge[None,None])
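# Added sanity check (not in the original notebook): the slow Python loop, the
# unfold+matmul version, and F.conv2d should all produce the same numbers, up to
# floating-point tolerance. The variable names below are introduced only for this check.
manual_out = tensor([[apply_kernel(i,j,left_edge) for j in rng] for i in rng])
unfold_out = (w@F.unfold(inp, (3,3))[0]).view(26,26)
conv2d_out = F.conv2d(inp, left_edge[None,None])[0,0]
torch.allclose(manual_out, unfold_out), torch.allclose(unfold_out, conv2d_out)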
# Two diagonal-edge kernels.
diag1_edge = tensor([[ 0,-1, 1],
                     [-1, 1, 0],
                     [ 1, 0, 0]]).float()
show_image(diag1_edge, noframe=False);

diag2_edge = tensor([[ 1,-1, 0],
                     [ 0, 1,-1],
                     [ 0, 0, 1]]).float()
show_image(diag2_edge, noframe=False);

# Apply all four kernels to a mini-batch of 16 images in one conv2d call.
# conv2d expects (batch, channels, height, width) inputs and
# (out_channels, in_channels, kh, kw) weights.
xb = x_imgs[:16][:,None]
xb.shape

edge_kernels = torch.stack([left_edge, top_edge, diag1_edge, diag2_edge])[:,None]
edge_kernels.shape

batch_features = F.conv2d(xb, edge_kernels)
batch_features.shape

img0 = xb[1,0]
show_image(img0);
show_images([batch_features[1,i] for i in range(4)])

# For comparison: the simple fully connected model from earlier notebooks.
n,m = x_train.shape
c = y_train.max()+1
nh = 50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

# A CNN built only from stride-1, padding-1 convolutions keeps the 28x28 spatial
# size, so its output is 10 feature maps per image rather than 10 activations.
broken_cnn = nn.Sequential(
    nn.Conv2d(1,30, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(30,10, kernel_size=3, padding=1)
)
broken_cnn(xb).shape

#|export
def conv(ni, nf, ks=3, stride=2, act=True):
    res = nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

# Each stride-2 conv halves the spatial size (rounding up), so five of them
# take 28x28 down to 1x1.
simple_cnn = nn.Sequential(
    conv(1 ,4),             #14x14
    conv(4 ,8),             #7x7
    conv(8 ,16),            #4x4
    conv(16,16),            #2x2
    conv(16,10, act=False), #1x1
    nn.Flatten(),
)
simple_cnn(xb).shape

# Rebuild the datasets with a channel dimension, as conv layers expect.
x_imgs = x_train.view(-1,1,28,28)
xv_imgs = x_valid.view(-1,1,28,28)
train_ds,valid_ds = Dataset(x_imgs, y_train),Dataset(xv_imgs, y_valid)

#|export
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)

def collate_device(b): return to_device(default_collate(b))

from torch import optim

bs = 256
lr = 0.4
train_dl,valid_dl = get_dls(train_ds, valid_ds, bs, collate_fn=collate_device)
opt = optim.SGD(simple_cnn.parameters(), lr=lr)
loss,acc = fit(5, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

# Continue training at a lower learning rate.
opt = optim.SGD(simple_cnn.parameters(), lr=lr/4)
loss,acc = fit(5, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

# Inspect the first conv layer's parameters.
simple_cnn[0][0]
conv1 = simple_cnn[0][0]
conv1.weight.shape
conv1.bias.shape

# A colour image has one 2D array per channel (red, green, blue).
from torchvision.io import read_image

im = read_image('images/grizzly.jpg')
im.shape
show_image(im.permute(1,2,0));

_,axs = plt.subplots(1,3)
for bear,ax,color in zip(im,axs,('Reds','Greens','Blues')): show_image(255-bear, ax=ax, cmap=color)

import nbdev; nbdev.nbdev_export()
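# Added appendix-style check (not part of the exported module): the spatial sizes
# noted next to simple_cnn above (28 -> 14 -> 7 -> 4 -> 2 -> 1) follow from the
# standard conv output-size formula, floor((n + 2*pad - ks)/stride) + 1.
# The helper name conv_out_size is hypothetical, used only for this illustration.
def conv_out_size(n, ks=3, stride=2, pad=1): return (n + 2*pad - ks)//stride + 1

szs = [28]
for _ in range(5): szs.append(conv_out_size(szs[-1]))
szs  # [28, 14, 7, 4, 2, 1]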