import sys
import logging
import numpy as np
import scipy as sp
import sklearn
import statsmodels.api as sm
from statsmodels.formula.api import ols
%load_ext autoreload
%autoreload 2
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import seaborn as sns
sns.set_context("poster")
sns.set(rc={'figure.figsize': (16, 9.)})
sns.set_style("whitegrid")
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
import torch
import torch.nn as nn
import torch.nn.functional as F
import mlp
import snalu
# Fix the random seeds so runs are reproducible: create_data samples its
# inputs with torch.randint, which draws from the torch RNG seeded here
# (model weight initialisation presumably uses the same RNG — confirm in
# mlp/snalu). numpy is seeded as well, although nothing visible in this
# notebook draws from numpy's RNG.
SEED = 123
np.random.seed(SEED)
# The trailing ';' only suppresses the notebook's echo of the call's return value.
torch.manual_seed(SEED);
def create_data(min_val, max_val, n_elts, fun_op, fun_name, single_dim=False):
    """Sample random integer-valued inputs and evaluate ``fun_op`` on them.

    Args:
        min_val: inclusive lower bound of the sampled integers.
        max_val: inclusive upper bound of the sampled integers.
        n_elts: number of rows drawn (before any filtering).
        fun_op: unary op when ``single_dim`` is True, otherwise a binary op
            applied to the two input columns.
        fun_name: operation name. ``'sqrt'`` raises the lower bound to at
            least 0 (single-column mode). ``'div'`` drops rows whose divisor
            (second column) is zero.
        single_dim: draw one input column instead of two.

    Returns:
        ``(x, y)``. ``x`` is a float tensor of shape ``(m, 1)`` or ``(m, 2)``,
        and ``y`` is ``fun_op``'s result, one value per row
        (``m <= n_elts``; ``m < n_elts`` only after the ``'div'`` filter).
    """
    if single_dim:
        if fun_name == 'sqrt':
            # sqrt of negatives is NaN: only raise the lower bound, never
            # lower a caller-supplied positive one.
            min_val = max(min_val, 0)
        x = torch.randint(low=min_val, high=max_val + 1, size=(n_elts, 1)).float()
        y = fun_op(x).reshape(-1)
    else:
        x = torch.randint(low=min_val, high=max_val + 1, size=(n_elts, 2)).float()
        if fun_name == 'div':
            # A boolean mask keeps x 2-D even when exactly one row survives;
            # nonzero().squeeze() would collapse that case to a 0-d index.
            x = x[x[:, 1] != 0]
        y = fun_op(x[:, 0], x[:, 1])
    return x, y
def split_data(data, less, greater, test_percentage=0.2):
    """Split ``(x, y)`` into interpolation train/test and extrapolation sets.

    A row belongs to the interpolation set when every feature lies in
    ``[less, greater]``. It belongs to the extrapolation set when any feature
    lies outside that range. Rows keep their original relative order. The
    first ``1 - test_percentage`` fraction of the interpolation rows is used
    for training and the remainder for testing.

    Args:
        data: ``(x, y)`` with ``x`` of shape ``(n, d)`` and ``y`` indexed by row.
        less: inclusive lower bound of the interpolation range.
        greater: inclusive upper bound of the interpolation range.
        test_percentage: fraction of interpolation rows held out for testing.

    Returns:
        ``(x_train, y_train), (x_test, y_test), (x_extra, y_extra)``.
    """
    x, y = data
    inter_mask = ((x >= less) & (x <= greater)).all(dim=1)
    # A single OR'ed mask keeps each extrapolation row exactly once, even when
    # one feature is below ``less`` and another above ``greater``.
    extra_mask = ((x < less) | (x > greater)).any(dim=1)
    # Boolean masks (instead of nonzero().squeeze() + index_select) keep the
    # results 2-D when exactly one row matches.
    x_inter, y_inter = x[inter_mask], y[inter_mask]
    x_extra, y_extra = x[extra_mask], y[extra_mask]
    cutoff = int((1.0 - test_percentage) * x_inter.shape[0])
    return ((x_inter[:cutoff], y_inter[:cutoff]),
            (x_inter[cutoff:], y_inter[cutoff:]),
            (x_extra, y_extra))
def train(model, data, n_epochs, optimizer, lr, verbose=False, *,
          early_break_max=70000, tol=0.05, log_every=50000):
    """Train ``model`` full-batch on ``data`` by minimising the MSE loss.

    Args:
        model: callable module; its output is flattened with ``reshape(-1)``.
        data: ``(x, y)`` tensors; ``y`` must match the flattened prediction.
        n_epochs: maximum number of full-batch optimisation steps.
        optimizer: optimizer class, instantiated as ``optimizer(params, lr=lr)``.
        lr: learning rate.
        verbose: print MSE/MAE every ``log_every`` epochs.
        early_break_max: stop once MSE and MAE have both been below ``tol``
            for this many consecutive epochs.
        tol: threshold for the early-stopping test.
        log_every: logging period (in epochs) when ``verbose`` is set.
    """
    opt = optimizer(model.parameters(), lr=lr)
    x, y = data
    early_break = 0
    for epoch in range(n_epochs):
        pred = model(x).reshape(-1)
        mse = F.mse_loss(pred, y)
        # MAE is only monitored, never back-propagated: keep it off the graph.
        with torch.no_grad():
            mae = torch.mean(torch.abs(pred - y))
        if mse.item() < tol and mae.item() < tol:
            early_break += 1
            # Checked before the step, so the final converged epoch makes no update.
            if early_break >= early_break_max:
                break
        else:
            early_break = 0
        opt.zero_grad()
        mse.backward()
        opt.step()
        if verbose and epoch % log_every == 0:
            print(f'Epoch: {epoch}: mse={round(mse.item(), 2)}; mae={round(mae.item(), 2)}')
def test(model, data):
    """Evaluate ``model`` on ``data`` without recording autograd history.

    Args:
        model: callable module; its output is flattened with ``reshape(-1)``.
        data: ``(x, y)`` tensors; ``y`` must match the flattened prediction.

    Returns:
        ``(mse, mae)`` as Python floats rounded to 2 decimals.
    """
    x, y = data
    # Pure evaluation: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        pred = model(x).reshape(-1)
        mse = F.mse_loss(pred, y)
        mae = torch.mean(torch.abs(pred - y))
    return round(mse.item(), 2), round(mae.item(), 2)


# The name starts with "test", so pytest would try to collect this helper
# (failing on the missing "model"/"data" fixtures) if the module were ever
# imported from a test file; opt out explicitly.
test.__test__ = False
# Target operations, keyed by name. Binary ops receive the two input columns
# separately; 'sqr' and 'sqrt' receive the single input column.
fun_dict = dict(
    add=lambda x, y: x + y,
    sub=lambda x, y: x - y,
    mul=lambda x, y: x * y,
    div=lambda x, y: x / y,
    sqr=lambda x: torch.pow(x, 2),
    sqrt=lambda x: torch.sqrt(x),
)
# MLP baselines, keyed by name and mapped to the torch.nn activation class
# the MLP is built with.
models = dict(
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    relu6=nn.ReLU6,
    softsign=nn.Softsign,
    selu=nn.SELU,
    elu=nn.ELU,
    relu=nn.ReLU,
)
# Entries without an activation class. 'none' still reaches mlp.MLP with
# act=None (presumably a purely linear MLP — confirm in mlp.MLP). The
# experiment loop builds NAC/NALU/SNALU from the snalu module instead.
models.update(dict.fromkeys(('none', 'NAC', 'NALU', 'SNALU')))
# Architecture shared by every model in the grid.
N_LAYERS, OUT_DIM, HIDDEN_DIM = 2, 1, 2

# Optimisation settings passed to train().
N_EPOCHS = 600_000
OPTIMIZER = torch.optim.Adam
LR = 1e-4

# Data generation: N_ELTS integers are drawn from DATA_RANGE. Rows entirely
# inside [LESS_THAN, GREATER_THAN] are used for interpolation; the rest are
# used for extrapolation.
DATA_RANGE = (-200, 200)
LESS_THAN, GREATER_THAN = -100, 100
N_ELTS = 1000

# Per-operation {model_name: mae} tables filled by the experiment loop.
interpolation_logs, extrapolation_logs = {}, {}
# Experiment grid. For every operation in fun_dict:
#   1. generate one dataset;
#   2. split it into interpolation train/test sets and an extrapolation set;
#   3. train each model variant on the interpolation train set;
#   4. record its MAE on both held-out sets.
# NOTE(review): this runs at module level, so the loop variables (fun_name,
# data, in_dim, data_train, ...) remain globals after the loop finishes.
for fun_name, fun_op in fun_dict.items():
    # Unary ops take a single input column; binary ops take two.
    if fun_name in ['sqr', 'sqrt']:
        single_dim = True
        in_dim = 1
    else:
        single_dim = False
        in_dim = 2
    data = create_data(*DATA_RANGE, N_ELTS, fun_op, fun_name, single_dim)
    # Rows with every feature inside [LESS_THAN, GREATER_THAN] form the
    # interpolation set; rows with any feature outside form the extrapolation set.
    data_train, data_test, data_extra = split_data(data, less=LESS_THAN, greater=GREATER_THAN)
    interpolation_logs[fun_name] = {}
    extrapolation_logs[fun_name] = {}
    for model_name, act in models.items():
        # NAC/NALU/SNALU are built from the snalu module; every other entry is
        # an mlp.MLP using the mapped activation class ('none' passes act=None).
        if model_name == 'NAC':
            model = snalu.StackedNAC(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM)
        elif model_name == 'NALU':
            model = snalu.StackedNALU(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM)
        elif model_name == 'SNALU':
            model = snalu.StackedSNALU(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM)
        else:
            model = mlp.MLP(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM, act)
        train(model, data_train, N_EPOCHS, OPTIMIZER, LR)
        # test() returns (mse, mae); only the MAE is recorded.
        _, mae_inter = test(model, data_test)
        _, mae_extra = test(model, data_extra)
        interpolation_logs[fun_name][model_name] = mae_inter
        extrapolation_logs[fun_name][model_name] = mae_extra
        print(f'{fun_name.ljust(10)}: {model_name.ljust(10)}: mae inter: {mae_inter}, mae extra: {mae_extra}')
        del model
# Recorded output of the experiment loop (one line per operation/model):
# add       : tanh      : mae inter: 7.11, mae extra: 53.6
# add       : sigmoid   : mae inter: 21.11, mae extra: 88.39
# add       : relu6     : mae inter: 0.0, mae extra: 26.46
# add       : softsign  : mae inter: 17.05, mae extra: 74.52
# add       : selu      : mae inter: 0.9, mae extra: 10.21
# add       : elu       : mae inter: 1.0, mae extra: 28.27
# add       : relu      : mae inter: 8.34, mae extra: 41.39
# add       : none      : mae inter: 0.0, mae extra: 0.0
# add       : NAC       : mae inter: 0.0, mae extra: 0.0
# add       : NALU      : mae inter: 50.02, mae extra: 115.55
# add       : SNALU     : mae inter: 0.0, mae extra: 0.0
# sub       : tanh      : mae inter: 9.54, mae extra: 77.89
# sub       : sigmoid   : mae inter: 24.22, mae extra: 120.5
# sub       : relu6     : mae inter: 14.42, mae extra: 74.67
# sub       : softsign  : mae inter: 18.96, mae extra: 96.26
# sub       : selu      : mae inter: 0.29, mae extra: 16.28
# sub       : elu       : mae inter: 1.8, mae extra: 34.9
# sub       : relu      : mae inter: 0.0, mae extra: 0.0
# sub       : none      : mae inter: 0.0, mae extra: 0.0
# sub       : NAC       : mae inter: 0.0, mae extra: 0.0
# sub       : NALU      : mae inter: 26.27, mae extra: 58.73
# sub       : SNALU     : mae inter: 0.0, mae extra: 0.0
# mul       : tanh      : mae inter: 1875.04, mae extra: 14560.35
# mul       : sigmoid   : mae inter: 1853.2, mae extra: 14594.65
# mul       : relu6     : mae inter: 1885.82, mae extra: 14456.46
# mul       : softsign  : mae inter: 1875.17, mae extra: 14560.83
# mul       : selu      : mae inter: 1437.63, mae extra: 11265.12
# mul       : elu       : mae inter: 1631.57, mae extra: 12512.69
# mul       : relu      : mae inter: 1058.95, mae extra: 9915.09
# mul       : none      : mae inter: 1875.3, mae extra: 14587.53
# mul       : NAC       : mae inter: 1894.42, mae extra: 14582.45
# mul       : NALU      : mae inter: 1876.45, mae extra: 14630.59
# mul       : SNALU     : mae inter: 0.0, mae extra: 0.02
# div       : tanh      : mae inter: 2.62, mae extra: 2.9
# div       : sigmoid   : mae inter: 2.62, mae extra: 2.9
# div       : relu6     : mae inter: 2.55, mae extra: 2.87
# div       : softsign  : mae inter: 3.17, mae extra: 3.02
# div       : selu      : mae inter: 2.67, mae extra: 3.83
# div       : elu       : mae inter: 2.81, mae extra: 4.52
# div       : relu      : mae inter: 2.57, mae extra: 3.83
# div       : none      : mae inter: 2.83, mae extra: 2.85
# div       : NAC       : mae inter: 2.87, mae extra: 2.9
# div       : NALU      : mae inter: 2.75, mae extra: 3.71
# div       : SNALU     : mae inter: 0.0, mae extra: 0.0
# sqr       : tanh      : mae inter: 3091.11, mae extra: 23239.66
# sqr       : sigmoid   : mae inter: 3073.51, mae extra: 23208.77
# sqr       : relu6     : mae inter: 2818.81, mae extra: 22954.94
# sqr       : softsign  : mae inter: 3096.14, mae extra: 23242.83
# sqr       : selu      : mae inter: 428.63, mae extra: 8812.06
# sqr       : elu       : mae inter: 441.3, mae extra: 9091.07
# sqr       : relu      : mae inter: 1473.83, mae extra: 15411.87
# sqr       : none      : mae inter: 2497.52, mae extra: 20020.89
# sqr       : NAC       : mae inter: 3190.51, mae extra: 23372.03
# sqr       : NALU      : mae inter: 3192.97, mae extra: 23320.59
# sqr       : SNALU     : mae inter: 0.0, mae extra: 0.01
# sqrt      : tanh      : mae inter: 0.06, mae extra: 1.04
# sqrt      : sigmoid   : mae inter: 0.06, mae extra: 1.12
# sqrt      : relu6     : mae inter: 0.4, mae extra: 2.6
# sqrt      : softsign  : mae inter: 0.05, mae extra: 0.52
# sqrt      : selu      : mae inter: 0.12, mae extra: 1.22
# sqrt      : elu       : mae inter: 0.13, mae extra: 1.23
# sqrt      : relu      : mae inter: 0.15, mae extra: 1.32
# sqrt      : none      : mae inter: 0.46, mae extra: 2.37
# sqrt      : NAC       : mae inter: 1.19, mae extra: 5.67
# sqrt      : NALU      : mae inter: 0.17, mae extra: 0.93
# sqrt      : SNALU     : mae inter: 0.02, mae extra: 0.19
# Scratch cell: rebuild a 'mul' dataset and a single StackedSNALU for manual
# inspection. The training/evaluation calls are commented out, so as written
# this only constructs an untrained model. It also rebinds the globals data,
# data_train, data_test, data_extra and model.
# No reseed happens here, so the sampled data generally differs from the
# 'mul' data used in the experiment grid.
data = create_data(*DATA_RANGE, N_ELTS, fun_dict['mul'], 'mul', single_dim=False)
data_train, data_test, data_extra = split_data(data, less=LESS_THAN, greater=GREATER_THAN)
model = snalu.StackedSNALU(N_LAYERS, in_dim=2, out_dim=OUT_DIM, hidden_dim=HIDDEN_DIM)
#train(model, data_train, N_EPOCHS, OPTIMIZER, lr=0.0001, verbose=True)
#print(test(model, data_test))
#print(test(model, data_extra))
def autolabel(rects, ax):
    """Annotate every bar in ``rects`` with its height.

    Each label is centred horizontally on its bar. Its bottom edge sits at 90%
    of the bar's height, so the text appears just inside the top of the bar.
    """
    for bar in rects:
        value = bar.get_height()
        center = bar.get_x() + bar.get_width() / 2.
        ax.text(center, 0.9 * value, str(value), ha='center', va='bottom')
import os


def _plot_mae_bars(ax, logs, title):
    """Draw one labelled bar per model from a ``{model_name: mae}`` mapping."""
    names = list(logs.keys())
    positions = np.arange(len(names))
    bars = ax.bar(positions, list(logs.values()), align='center', alpha=0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels(names)
    ax.set_ylabel('mae')
    ax.set_title(title)
    autolabel(bars, ax)


# One subplot row per operation: interpolation MAE on the left,
# extrapolation MAE on the right. idx is the 1-based subplot index.
idx = 1
n_rows = len(interpolation_logs)
figure = plt.figure(figsize=(20, 40))
for fun_name in interpolation_logs:
    _plot_mae_bars(figure.add_subplot(n_rows, 2, idx),
                   interpolation_logs[fun_name], f'{fun_name} (interpolation)')
    _plot_mae_bars(figure.add_subplot(n_rows, 2, idx + 1),
                   extrapolation_logs[fun_name], f'{fun_name} (extrapolation)')
    idx += 2
#plt.show()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('images', exist_ok=True)
plt.savefig('images/results.png')