#!/usr/bin/env python
# coding: utf-8

# # Reading memory in a deep network with ST Gumbel-softmax
#
# Using an external memory bank in a feedforward neural network.
#
# Author: [Jacob Reinhold](https://www.jcreinhold.com)
#
# Created on: Sept. 8, 2020

# ## Setup notebook

# In[1]:

from typing import *
from functools import partial
from glob import glob
import os
import random
import sys
import warnings

gpu_id = 1
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

# Support in-notebook plotting

# In[2]:

get_ipython().run_line_magic('matplotlib', 'inline')

# Report versions

# In[3]:

print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')

# In[4]:

pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))

# Reload packages whose content has changed (useful for package development)

# In[5]:

get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

# Check GPU(s)

# In[6]:

get_ipython().system('nvidia-smi | head -n 4')

# Set seeds for better reproducibility

# In[7]:

seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# ## Setup training and validation data
#
# Make up some toy data, e.g., $y = \sum_{i=1}^D x_i^2 + \varepsilon$ for $\varepsilon \stackrel{iid}{\sim} \mathcal N(0,\sigma^2)$

# In[8]:

def func(x):
    return x**2

# In[9]:

class ToyRegressionData(Dataset):
    def __init__(self, n:int, dim:int, x_params:Tuple[float,float],
                 noise_std:float=0.01, dist=None):
        self.n = n
        self.x_params = x_params
        self.dim = dim
        self.noise_std = noise_std
        if dist is None:
            # no distribution given: use evenly-spaced points on [x_params[0], x_params[1]]
            self.x_dist = None
            self.x = torch.linspace(x_params[0], x_params[1], n)
            self.x.unsqueeze_(1)
        else:
            # otherwise sample x from the given distribution parameterized by x_params
            self.x_dist = dist(*x_params)
            self.x = self.x_dist.sample((n, dim))
        y = torch.sum(func(self.x), dim=1, keepdim=True)
        self.y = y + noise_std * torch.randn_like(y)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx:int):
        return (self.x[idx], self.y[idx])
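# Quick sanity check of the dataset class: build a small throwaway dataset and confirm the
# shapes returned by `__getitem__` (a minimal sketch; `_demo_ds` is only an illustrative name).

# In[ ]:

_demo_ds = ToyRegressionData(8, 1, (-1., 1.), noise_std=0.)
_x0, _y0 = _demo_ds[0]
print(f'len = {len(_demo_ds)}, x shape = {tuple(_x0.shape)}, y shape = {tuple(_y0.shape)}')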
# In[10]:

dim = 1
N = 2**12
train_dist = None
valid_dist = torch.distributions.Normal
xt_params, xv_params = (-1., 1.), (0., 0.5)
train_std, valid_std = 0.01, 0.
print(f'N = {N}')
train_ds_params = dict(x_params=xt_params, noise_std=train_std, dist=train_dist)
valid_ds_params = dict(x_params=xv_params, noise_std=valid_std, dist=valid_dist)
train_ds = ToyRegressionData(N, dim, **train_ds_params)
valid_ds = ToyRegressionData(N//2, dim, **valid_ds_params)

# In[11]:

plot_data = N <= 2**12 and dim == 1
if plot_data:
    ixs = np.argsort(train_ds.x[:,0])
    plt.plot(train_ds.x[ixs], train_ds.y[ixs], lw=1)
    plt.title('Training data');

# ## Setup model

# In[12]:

def sample_gumbel(logits:Tensor, eps:float=1e-8):
    # sample standard Gumbel noise with the same shape as `logits`
    U = torch.rand_like(logits)
    return -torch.log(-torch.log(U + eps) + eps)

def sample_gumbel_softmax(logits:Tensor, temperature:float):
    # soft (differentiable) sample from the Gumbel-softmax distribution
    y = logits + sample_gumbel(logits)
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits:Tensor, temperature:float=0.67):
    # straight-through (ST) Gumbel-softmax: the forward pass returns a hard one-hot vector,
    # while the backward pass uses the gradient of the soft sample
    y = sample_gumbel_softmax(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    return (y_hard - y).detach() + y

# In[13]:

def to_numpy(x:Tensor):
    return x.detach().cpu().numpy()

# Example of how Gumbel-softmax sampling works for a given distribution at several temperatures. (Not directly related to the reading-from-memory task.)

# In[14]:

x = [1, 2, 3, 4]
probs = torch.tensor([0.1, 0.2, 0.6, 0.1])  # true class probabilities
temperatures = [0.1, 0.25, 0.5, 2.0]
f, axs = plt.subplots(1, len(temperatures)+1, sharex=True, sharey=True, figsize=(16,4))
axs[0].bar(x, probs)
axs[0].set_title("True distribution")
for t, ax in zip(temperatures, axs[1:]):
    sample = sample_gumbel_softmax(torch.log(probs), t)
    sample = to_numpy(sample)
    ax.bar(x, sample)
    ax.set_title(r"$\tau = " + f"{t}$")

# In[15]:

activation = partial(nn.ReLU, inplace=True)

def linear(in_features:int, out_features:int, dropout_rate:float=0.):
    layers = [nn.Linear(in_features, out_features, bias=False),
              nn.BatchNorm1d(out_features),
              activation()]
    if dropout_rate > 0.:
        layers.append(nn.Dropout(dropout_rate))
    return layers

class MemoryTensor(nn.Module):
    def __init__(self, x:Tensor):
        super().__init__()
        x.unsqueeze_(1)  # store memory as a (mem_size, 1) column vector (modifies `x` in place)
        self.memory = x

    def __getitem__(self, idx:Tensor):
        if self.training:
            # differentiable read: ST Gumbel-softmax over the logits times the memory bank
            idx = gumbel_softmax(idx)
            out = idx @ self.memory
        else:
            # hard read: look up the most likely memory element
            idx = torch.argmax(idx, dim=1)
            out = self.memory[idx]
        return out

class MemoryNet(nn.Module):
    def __init__(self, dim:int, mid:int, memory:Tensor, n_layers:int=2,
                 dropout_rate:float=0., out_bias:bool=True):
        super().__init__()
        self.memory = MemoryTensor(memory)
        mem_size = len(memory)
        layers = linear(dim, mid, dropout_rate)
        for _ in range(n_layers-1):
            layers.extend(linear(mid, mid, dropout_rate=dropout_rate))
        self.h = nn.Sequential(*layers)
        self.o = nn.Linear(mid, mem_size, bias=out_bias)

    def forward(self, x):
        idx = self.o(self.h(x))  # logits over the memory elements
        x = self.memory[idx]     # read from memory
        return x
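# The straight-through read above returns a hard one-hot selection in the forward pass while
# letting gradients flow through the soft Gumbel-softmax sample. A minimal sketch of this
# behavior on random logits (the underscore-prefixed names are illustrative only):

# In[ ]:

_logits = torch.randn(3, 5, requires_grad=True)
_hard = gumbel_softmax(_logits, temperature=0.67)
print(_hard)  # each row is (numerically) a one-hot vector
_hard.sum().backward()
print('logits receive gradients:', _logits.grad is not None)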
# ## Setup training

# Set training and model hyperparameters

# In[16]:

device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

# In[17]:

# system setup
load_model = False

# logging setup
log_rate = 100  # print losses every `log_rate` epochs
version = 'v1'  # version tag used when saving/loading the model

# model, optimizer, loss, and training parameters
batch_size = N // 2
n_jobs = 0
n_epochs = 200
mem_size = 5
memory = torch.linspace(0., 1., steps=mem_size, device=device)
model_kwargs = dict(dim=dim, mid=48, memory=memory, n_layers=4,
                    dropout_rate=0., out_bias=True)
use_adam = True
lr = 1e-3
weight_decay = 4e-3
momentum = 0.9
optim_kwargs = (dict(lr=lr, betas=(momentum, 0.99), weight_decay=weight_decay) if use_adam else
                dict(lr=lr, momentum=momentum, weight_decay=weight_decay))
use_scheduler = True
scheduler_kwargs = dict(step_size=50, gamma=0.5)
use_l2 = True  # use L2 (MSE) loss if True, else a smooth L1 loss

# In[18]:

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=n_jobs, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=batch_size,
                          num_workers=n_jobs, pin_memory=True)

# In[19]:

model = MemoryNet(**model_kwargs)

# In[20]:

def num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# In[21]:

print(f'Number of trainable parameters in model: {num_params(model)}')

# In[22]:

if load_model:
    state_dict = torch.load(f'trained_{version}.pth')
    model.load_state_dict(state_dict);

# In[23]:

model.cuda(device=device)
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
optimizer = optim_cls(model.parameters(), **optim_kwargs)
if use_scheduler:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_kwargs)
criterion = nn.MSELoss() if use_l2 else nn.SmoothL1Loss()

# ## Train model

# In[24]:

train_losses, valid_losses = [], []
n_batches = len(train_loader)

# In[25]:

for t in range(1, n_epochs+1):
    # training
    t_losses = []
    model.train()
    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        t_losses.append(loss.item())
        loss.backward()
        optimizer.step()
    train_losses.append(t_losses)

    # validation
    v_losses = []
    model.eval()
    with torch.no_grad():
        for i, (x, y) in enumerate(valid_loader):
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)

    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
    if t % log_rate == 0:
        log = f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}'
        print(log)
    if use_scheduler:
        scheduler.step()

# ## Analyze training

# In[26]:

def tidy_losses(train:list, valid:list):
    # flatten per-epoch loss lists into a tidy (long-format) DataFrame for plotting
    out = {'epoch': [], 'type': [], 'value': [], 'phase': []}
    for i, (tl, vl) in enumerate(zip(train, valid), 1):
        for tli in tl:
            out['epoch'].append(i)
            out['type'].append('loss')
            out['value'].append(tli)
            out['phase'].append('train')
        for vli in vl:
            out['epoch'].append(i)
            out['type'].append('loss')
            out['value'].append(vli)
            out['phase'].append('valid')
    return pd.DataFrame(out)

# In[27]:

losses = tidy_losses(train_losses, valid_losses)

# In[28]:

f, ax1 = plt.subplots(1, 1, figsize=(12, 8), sharey=True)
sns.lineplot(x='epoch', y='value', hue='phase', data=losses, ci='sd', ax=ax1, lw=3);
ax1.set_yscale('log');
ax1.set_title('Losses');
f.savefig(f'losses_{version}.pdf')

# In[29]:

save_tidy = False
if save_tidy:
    losses.to_csv(f'losses_{version}.csv')

# In[30]:

save_model = True
if save_model and not load_model:
    torch.save(model.state_dict(), f'trained_{version}.pth')

# ## Examine and evaluate results

# In[31]:

xt, yt = train_ds.x, train_ds.y
xv, yv = valid_ds.x, valid_ds.y
xtn, ytn = xt.numpy(), yt.numpy()
xvn, yvn = xv.numpy(), yv.numpy()

# In[32]:

model.eval()
with torch.no_grad():
    yhattn = to_numpy(model(xt.to(device)))
    yhatvn = to_numpy(model(xv.to(device)))
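# Beyond the qualitative plots below, a simple quantitative check is the mean squared error of
# the predictions against the targets. A minimal sketch using the arrays computed above:

# In[ ]:

print(f'train MSE: {np.mean((yhattn - ytn)**2):.2e}')
print(f'valid MSE: {np.mean((yhatvn - yvn)**2):.2e}')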
# In[33]:

tss, vss = 1, 1
xtns, xvns = xtn[::tss,0], xvn[::vss,0]
xr = np.abs(train_ds.x_params[1])
xtm = (xtns > -xr) & (xtns < xr)
xvm = (xvns > -xr) & (xvns < xr)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15,5), sharex=True)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sns.regplot(xtns[xtm], ytn[::tss,0][xtm], order=2, ax=ax1)
    sns.regplot(xvns[xvm], yvn[::vss,0][xvm], order=2, ax=ax2)
    sns.regplot(xtns[xtm], yhattn[::tss,0][xtm], order=2, ax=ax3)
    sns.regplot(xvns[xvm], yhatvn[::vss,0][xvm], order=2, ax=ax4)
ax1.set_title('Training data');
ax2.set_title('Validation data');
ax3.set_title('Predicted training');
ax4.set_title('Predicted validation');
ax1.set_xlim((-xr, xr));
for ax in (ax1, ax2, ax3, ax4):
    ax.set_ylim((-0.1, xr**2))
fig.savefig(f'predicted_{version}.pdf')

# ### Examine the function fit by the NN

# In[34]:

def optimal_func(x, memory):
    # best achievable fit given the memory bank: the memory element closest to the true f(x)
    i = np.argmin(np.abs(func(x) - memory), axis=1)
    return memory[:,i].squeeze()

# In[35]:

if plot_data:
    fig, ax1 = plt.subplots(1, 1, figsize=(8,8))
    x = train_ds.x[:,0]
    ixs = np.argsort(x)
    x_range = (-2, 2)
    xs = torch.linspace(*x_range, 250).unsqueeze(1)
    with torch.no_grad():
        ys = to_numpy(model(xs.to(device)))
    x, y = xs[:,0], ys[:,0]
    pal = sns.color_palette()
    ax1.plot(train_ds.x[ixs], train_ds.y[ixs],
             alpha=0.5, color=pal[1], label='Truth')
    ax1.plot(x, y, label='Fit', lw=2, color=pal[0], linestyle='dashed')
    # `memory` was unsqueezed in place by MemoryTensor, so it is a (mem_size, 1) column vector
    mem = memory.cpu().detach().numpy().T
    ax1.plot(train_ds.x[ixs], optimal_func(train_ds.x[ixs], mem),
             alpha=0.5, lw=3, color=pal[2], label='Best Fit')
    ax1.legend(loc='best');
    fig.savefig(f'fit_{version}.pdf')

# In[36]:

# `y` here is the model output over the dense grid `xs` from the previous cell
print('Elements from memory used: ')
print(' '.join([f'{x:0.3f}' for x in np.unique(y).tolist()]))
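# To see how the network uses the memory bank, one can also count how often each slot is
# selected at eval time (the read is a hard argmax over the output logits). A minimal sketch
# using the model and training inputs from above:

# In[ ]:

with torch.no_grad():
    _slot_logits = model.o(model.h(xt.to(device)))
    _slots = torch.argmax(_slot_logits, dim=1)
print('memory slot usage on training data:')
print(torch.bincount(_slots.cpu(), minlength=mem_size).tolist())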