Using an external memory bank in a feedforward neural network.
Author: Jacob Reinhold
Created on: Sept. 8, 2020
from typing import Tuple
from functools import partial
from glob import glob
import os
import random
import sys
import warnings
gpu_id = 1
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
Support in-notebook plotting
%matplotlib inline
Report versions
print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')
numpy version: 1.19.1
matplotlib version: 3.2.2
pytorch version: 1.6.0
pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))
python version: 3.8.5
Reload packages when their content changes (useful for package development)
%load_ext autoreload
%autoreload 2
Check GPU(s)
!nvidia-smi | head -n 4
Wed Sep 9 11:56:26 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
Set seeds for better reproducibility
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
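For stricter GPU-side determinism one could additionally pin down cuDNN, at some speed cost (a hypothetical addition, not used in this run, since it conflicts with the `cudnn.benchmark = True` setting below):
torch.backends.cudnn.deterministic = True  # hypothetical: not enabled in this run
torch.backends.cudnn.benchmark = False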
Make up some toy data, e.g., $y = \sum_{i=1}^D x_i^2 + \varepsilon$ for $\varepsilon \stackrel{iid}{\sim} \mathcal N(0,\sigma^2)$
def func(x):
return x**2
class ToyRegressionData(Dataset):
def __init__(self, n:int, dim:int,
x_params:Tuple[float,float],
noise_std:float=0.01,
dist=None):
self.n = n
self.x_params = x_params
self.dim = dim
self.noise_std = noise_std
        if dist is None:
            # no sampling distribution given: use an evenly-spaced grid
            # (note: this branch assumes dim == 1)
            self.x_dist = None
            self.x = torch.linspace(x_params[0], x_params[1], n)
            self.x.unsqueeze_(1)
else:
self.x_dist = dist(*x_params)
self.x = self.x_dist.sample((n,dim))
        y = torch.sum(func(self.x), dim=1, keepdim=True)
self.y = y + noise_std * torch.randn_like(y)
def __len__(self):
return self.x.size(0)
def __getitem__(self, idx:int):
return (self.x[idx], self.y[idx])
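As a quick sanity check (a hypothetical snippet, not part of the original notebook), a noiseless one-dimensional dataset should return targets exactly equal to $x^2$:
# hypothetical sanity check: with no noise, y is exactly x^2
_ds = ToyRegressionData(5, 1, (-1., 1.), noise_std=0.)
_x, _y = _ds[0]
print(_x, _y)  # tensor([-1.]) tensor([1.])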
dim = 1
N = 2**12
train_dist = None
valid_dist = torch.distributions.Normal
xt_params, xv_params = (-1.,1.), (0.,0.5)
train_std, valid_std = 0.01, 0.
print(f'N = {N}')
train_ds_params = dict(x_params=xt_params, noise_std=train_std, dist=train_dist)
valid_ds_params = dict(x_params=xv_params, noise_std=valid_std, dist=valid_dist)
train_ds = ToyRegressionData(N, dim, **train_ds_params)
valid_ds = ToyRegressionData(N//2, dim, **valid_ds_params)
plot_data = N <= 2**12 and dim == 1
if plot_data:
ixs = np.argsort(train_ds.x[:,0])
plt.plot(train_ds.x[ixs],train_ds.y[ixs],lw=1)
plt.title('Training data');
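The functions below implement Gumbel-softmax sampling. Given logits $\log \pi_i$ and noise $g_i \stackrel{iid}{\sim} \text{Gumbel}(0,1)$, a relaxed categorical sample is $y = \operatorname{softmax}\left((\log \pi + g)/\tau\right)$, which approaches a one-hot vector as the temperature $\tau \to 0$. The straight-through variant returns the hard argmax in the forward pass but backpropagates through the soft sample, which keeps reads from the memory bank differentiable.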
def sample_gumbel(logits:Tensor, eps:float=1e-8):
    # sample Gumbel(0,1) noise with the same shape as logits
    U = torch.rand_like(logits)
    return -torch.log(-torch.log(U + eps) + eps)

def sample_gumbel_softmax(logits:Tensor, temperature:float):
    # draw a soft (relaxed) sample from the categorical distribution given by logits
    y = logits + sample_gumbel(logits)
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits:Tensor, temperature:float=0.67):
    # straight-through Gumbel-softmax: one-hot forward pass, soft gradients backward
    y = sample_gumbel_softmax(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # y_hard is the value used in the forward pass; gradients flow through y
    return (y_hard - y).detach() + y
def to_numpy(x:Tensor):
return x.detach().cpu().numpy()
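To see the straight-through trick concretely (a hypothetical check, not in the original notebook): the forward value of `gumbel_softmax` is exactly one-hot, while gradients flow through the soft sample.
# hypothetical check of the straight-through estimator
_logits = torch.tensor([[1., 2., 3.]], requires_grad=True)
_y = gumbel_softmax(_logits, temperature=0.67)
print(_y)  # exactly one-hot, e.g., tensor([[0., 0., 1.]])
_loss = (_y * torch.tensor([[0., 1., 2.]])).sum()
_loss.backward()
print(_logits.grad)  # nonzero: gradients pass through the soft sample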
An example of how Gumbel-softmax sampling behaves for a given distribution at various temperatures. (Not directly related to the memory-reading task.)
x = [1, 2, 3, 4]
probs = torch.tensor([0.1, 0.2, 0.6, 0.1])  # target categorical probabilities
temperatures = [0.1, 0.25, 0.5, 2.0]
f, axs = plt.subplots(1, len(temperatures)+1,
sharex=True, sharey=True,
figsize=(16,4))
axs[0].bar(x, probs)
axs[0].set_title("True distribution")
for t, ax in zip(temperatures, axs[1:]):
    sample = sample_gumbel_softmax(torch.log(probs), t)
sample = to_numpy(sample)
ax.bar(x, sample)
ax.set_title(r"$\tau = " + f"{t}$")
activation = partial(nn.ReLU, inplace=True)
def linear(in_features:int, out_features:int, dropout_rate:float=0.):
    # bias is omitted from the linear layer since batch norm adds its own shift
    layers = [nn.Linear(in_features, out_features, bias=False),
              nn.BatchNorm1d(out_features),
              activation()]
    if dropout_rate > 0.:
        layers.append(nn.Dropout(dropout_rate))
    return layers
class MemoryTensor(nn.Module):
    # a fixed memory bank read via straight-through Gumbel-softmax attention
    def __init__(self, x:Tensor):
        super().__init__()
        x.unsqueeze_(1)  # (mem_size,) -> (mem_size, 1)
        self.memory = x
    def __getitem__(self, idx:Tensor):
        if self.training:
            # soft read: (straight-through) one-hot weights times the memory
            idx = gumbel_softmax(idx)
            out = idx @ self.memory
        else:
            # hard read: index the memory with the argmax of the logits
            idx = torch.argmax(idx, dim=1)
            out = self.memory[idx]
        return out
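A quick illustration of the two read paths (hypothetical, for exposition):
# hypothetical usage of MemoryTensor with a 3-slot memory bank
_mem = MemoryTensor(torch.tensor([0., 0.5, 1.]))  # stored with shape (3, 1)
_logits = torch.tensor([[0.1, 0.2, 5.0]])
_mem.train()
print(_mem[_logits])  # stochastic Gumbel read, shape (1, 1); usually selects 1.0
_mem.eval()
print(_mem[_logits])  # deterministic argmax read, shape (1, 1); selects 1.0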
class MemoryNet(nn.Module):
def __init__(self, dim:int, mid:int,
memory:Tensor, n_layers:int=2,
dropout_rate:float=0.,
out_bias:bool=True):
super().__init__()
self.memory = MemoryTensor(memory)
mem_size = len(memory)
layers = linear(dim, mid, dropout_rate)
for _ in range(n_layers-1):
layers.extend(
linear(mid, mid, dropout_rate=dropout_rate))
self.h = nn.Sequential(*layers)
self.o = nn.Linear(mid, mem_size, bias=out_bias)
def forward(self, x):
idx = self.o(self.h(x))
x = self.memory[idx]
return x
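A shape sanity check (hypothetical, not in the original notebook): the network maps a batch of inputs to one scalar read from memory per example.
# hypothetical shape check on the CPU
_net = MemoryNet(dim=1, mid=8, memory=torch.linspace(0., 1., 5), n_layers=2)
_net.eval()  # use the deterministic argmax read
print(_net(torch.zeros(4, 1)).shape)  # torch.Size([4, 1])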
Set training and model hyperparameters
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True
# system setup
load_model = False
# logging setup
log_rate = 100 # print losses every log_rate epochs
version = 'v1' # naming scheme of model to load
# model, optimizer, loss, and training parameters
batch_size = N // 2
n_jobs = 0
n_epochs = 200
mem_size = 5
memory = torch.linspace(0., 1., steps=mem_size, device=device)
model_kwargs = dict(dim=dim, mid=48,
memory=memory,
n_layers=4,
dropout_rate=0.,
out_bias=True)
use_adam = True
lr = 1e-3
weight_decay = 4e-3
momentum = 0.9
optim_kwargs = dict(lr=lr, betas=(momentum,0.99), weight_decay=weight_decay) if use_adam else \
dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
use_scheduler = True
scheduler_kwargs = dict(step_size=50, gamma=0.5)
use_l2 = True # use l2 (MSE) loss or l1 loss function
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
num_workers=n_jobs, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=batch_size,
num_workers=n_jobs, pin_memory=True)
model = MemoryNet(**model_kwargs)
def num_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of trainable parameters in model: {num_params(model)}')
Number of trainable parameters in model: 7589
if load_model:
state_dict = torch.load(f'trained_{version}.pth')
model.load_state_dict(state_dict);
model.to(device)
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
optimizer = optim_cls(model.parameters(), **optim_kwargs)
if use_scheduler:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_kwargs)
criterion = nn.MSELoss() if use_l2 else nn.SmoothL1Loss()
train_losses, valid_losses = [], []
n_batches = len(train_loader)
for t in range(1, n_epochs+1):
# training
t_losses = []
model.train()
for i, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
out = model(x)
loss = criterion(out, y)
t_losses.append(loss.item())
loss.backward()
optimizer.step()
train_losses.append(t_losses)
# validation
v_losses = []
model.eval()
with torch.no_grad():
for i, (x, y) in enumerate(valid_loader):
x, y = x.to(device), y.to(device)
out = model(x)
loss = criterion(out, y)
v_losses.append(loss.item())
valid_losses.append(v_losses)
if not np.all(np.isfinite(t_losses)):
raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
if t % log_rate == 0:
log = (f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
print(log)
if use_scheduler: scheduler.step()
Epoch: 100 - TL: 1.82e-02, VL: 3.16e-02
Epoch: 200 - TL: 1.17e-02, VL: 2.28e-02
def tidy_losses(train:list, valid:list):
out = {'epoch': [], 'type': [], 'value': [], 'phase': []}
for i, (tl,vl) in enumerate(zip(train,valid),1):
for tli in tl:
out['epoch'].append(i)
out['type'].append('loss')
out['value'].append(tli)
out['phase'].append('train')
for vli in vl:
out['epoch'].append(i)
out['type'].append('loss')
out['value'].append(vli)
out['phase'].append('valid')
return pd.DataFrame(out)
losses = tidy_losses(train_losses, valid_losses)
f, ax1 = plt.subplots(1,1,figsize=(12, 8),sharey=True)
sns.lineplot(x='epoch',y='value',hue='phase',data=losses,ci='sd',ax=ax1,lw=3);
ax1.set_yscale('log');
ax1.set_title('Losses');
f.savefig(f'losses_{version}.pdf')
save_tidy = False
if save_tidy:
losses.to_csv(f'losses_{version}.csv')
save_model = True
if save_model and not load_model:
torch.save(model.state_dict(), f'trained_{version}.pth')
xt, yt = train_ds.x, train_ds.y
xv, yv = valid_ds.x, valid_ds.y
xtn, ytn = xt.numpy(), yt.numpy()
xvn, yvn = xv.numpy(), yv.numpy()
model.eval()
with torch.no_grad():
yhattn = to_numpy(model(xt.to(device)))
yhatvn = to_numpy(model(xv.to(device)))
tss, vss = 1, 1  # subsampling strides for plotting
xtns, xvns = xtn[::tss,0], xvn[::vss,0]
xr = np.abs(train_ds.x_params[1])
xtm = (xtns > -xr) & (xtns < xr)
xvm = (xvns > -xr) & (xvns < xr)
fig,(ax1,ax2,ax3,ax4) = plt.subplots(1,4,figsize=(15,5),sharex=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
sns.regplot(xtns[xtm], ytn[::tss,0][xtm], order=2, ax=ax1)
sns.regplot(xvns[xvm], yvn[::vss,0][xvm], order=2, ax=ax2)
sns.regplot(xtns[xtm], yhattn[::tss,0][xtm], order=2, ax=ax3)
sns.regplot(xvns[xvm], yhatvn[::vss,0][xvm], order=2, ax=ax4)
ax1.set_title('Training data');
ax2.set_title('Validation data');
ax3.set_title('Predicted training');
ax4.set_title('Predicted validation');
ax1.set_xlim((-xr,xr));
for ax in (ax1,ax2,ax3,ax4): ax.set_ylim((-0.1,xr**2))
fig.savefig(f'predicted_{version}.pdf')
def optimal_func(x, memory):
    # best attainable prediction: the memory element nearest to func(x)
    i = np.argmin(np.abs(func(x) - memory), axis=1)
    return memory[:,i].squeeze()
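For intuition (a hand-worked example, not from the notebook): with memory $\{0, 0.25, 0.5, 0.75, 1\}$, the best the model can do at $x = 0.7$ is return the element nearest $x^2 = 0.49$, namely $0.5$:
# hypothetical spot check of the nearest-memory-element baseline
_x0 = np.array([[0.7]])
_mem0 = np.array([[0., 0.25, 0.5, 0.75, 1.]])
print(optimal_func(_x0, _mem0))  # 0.5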
if plot_data:
fig, ax1 = plt.subplots(1,1,figsize=(8,8))
x = train_ds.x[:,0]
ixs = np.argsort(x)
x_range = (-2, 2)
xs = torch.linspace(*x_range, 250).unsqueeze(1)
with torch.no_grad():
ys = to_numpy(model(xs.to(device)))
x, y = xs[:,0], ys[:,0]
pal = sns.color_palette()
ax1.plot(train_ds.x[ixs],train_ds.y[ixs],
alpha=0.5,color=pal[1],label='Truth')
ax1.plot(x,y,label='Fit',lw=2,color=pal[0],linestyle='dashed')
mem = memory.cpu().detach().numpy().T
ax1.plot(train_ds.x[ixs], optimal_func(train_ds.x[ixs], mem),
alpha=0.5,lw=3,color=pal[2],label='Best Fit')
ax1.legend(loc='best');
fig.savefig(f'fit_{version}.pdf')
print('Elements from memory used: ')
print(' '.join([f'{x:0.3f}' for x in np.unique(y).tolist()]))
Elements from memory used: 0.000 0.250 0.500 0.750 1.000