Implementation of a Bayesian linear regression model. Weights are estimated with Stochastic Variational Inference (SVI) using PyTorch distributions.
$y \sim \mathcal{N}(\alpha x + \beta, \sigma)$
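The parameters $\theta = (\alpha, \beta, \sigma)$ are given a Gaussian prior, and their posterior is approximated by a diagonal Gaussian $q$ fitted by minimizing the negative ELBO: the expected negative log likelihood plus the KL divergence between $q$ and the prior (scaled to the per-example level when training on mini-batches):

$\mathcal{L} = \mathbb{E}_{q(\theta)}\left[-\log p(y \mid x, \theta)\right] + \mathrm{KL}\left(q(\theta) \,\|\, p(\theta)\right)$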
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.model_selection import train_test_split
cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
class LinearNormal(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Prior over (α, β, σ): independent Gaussians with variances (10, 10, 1).
        self.prior = torch.distributions.MultivariateNormal(
            torch.zeros((3,)),
            torch.Tensor([10, 10, 1]) * torch.eye(3),
        )
        # Variational parameters: posterior mean and (pre-softplus) scale.
        self.mean = torch.nn.Parameter(torch.randn((3,)))
        self.sigma = torch.nn.Parameter(torch.randn((3,)))

    def forward(self, x, sample=True):
        # Softplus keeps the posterior scale positive.
        σ = torch.log(1 + torch.exp(self.sigma))
        dist = torch.distributions.MultivariateNormal(
            self.mean,
            σ * torch.eye(3),
        )
        if sample:
            # rsample applies the reparameterization trick so gradients
            # flow back to the variational parameters.
            alpha, beta, sigma = dist.rsample()
        else:
            alpha, beta, sigma = dist.mean
        m = alpha * x + beta
        s = torch.log(1 + torch.exp(sigma))  # likelihood scale, kept positive
        y_hat = torch.distributions.Normal(m, s)
        kl_divergence = torch.distributions.kl.kl_divergence(dist, self.prior)
        return y_hat, kl_divergence
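A quick smoke test of the module (illustrative only; the underscore-prefixed names are hypothetical): a forward pass returns a Normal over the outputs together with the KL term.

_model = LinearNormal()
_x = torch.linspace(-1, 1, 8)
_y_hat, _kl = _model(_x)
print(_y_hat.mean.shape)  # torch.Size([8]): one predictive mean per input
print(float(_kl))         # scalar: KL(q || prior)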
def compute_loss(model, x, y):
    y_hat, kl_divergence = model(x)
    # Negative ELBO: average negative log likelihood plus the KL term,
    # scaled to the per-example level so it matches the batch average.
    return torch.mean(-y_hat.log_prob(y)) + kl_divergence / len(x)
def train_one_step(model, optimizer, x_batch, y_batch):
    model.train()
    optimizer.zero_grad()
    loss = compute_loss(model, x_batch, y_batch)
    loss.backward()
    optimizer.step()
    return loss
def compute_rmse(model, x_test, y_test):
    model.eval()
    y_hat, _ = model(x_test, sample=False)
    pred = y_hat.sample()
    return torch.sqrt(torch.mean((pred - y_test)**2))
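Note that compute_rmse draws a sample from the predictive distribution, so the reported RMSE includes the aleatoric noise: for a well-fitted model it settles near $\sqrt{2}\,\sigma \approx 1.0$ rather than $\sigma = 0.7$. A deterministic variant (a hypothetical alternative, not used below) would score the predictive mean instead:

def compute_rmse_mean(model, x_test, y_test):
    # Hypothetical variant: score the predictive mean, excluding sampling noise.
    model.eval()
    y_hat, _ = model(x_test, sample=False)
    return torch.sqrt(torch.mean((y_hat.mean - y_test) ** 2))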
α_actual = 2.6
β_actual = 3.3
σ_actual = 0.7
def generate_samples(α, β, σ, min_x=-1, max_x=1, n_samples=500):
    x = np.linspace(min_x, max_x, n_samples)[:, np.newaxis]
    y = α * x + β
    dist = torch.distributions.Normal(torch.from_numpy(y), σ)
    return x, y, dist.sample().detach().numpy()
def plot_line(x, y, y_sample):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(x.flatten(), y.flatten(), '-', color=palette[0], linewidth=3)
    ax.scatter(x.flatten(), y_sample.flatten(), color=palette[0], alpha=0.8)
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    return ax
x, y, y_sample = generate_samples(α_actual, β_actual, σ_actual)
plot_line(x, y, y_sample);
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size=64, print_every=10):
    train_losses, val_losses = [], []
    for epoch in range(n_epochs):
        batch_indices = sample_batch_indices(x_train, y_train, batch_size)
        batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
        for batch_ix in batch_indices:
            b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])
            model.eval()
            b_val_loss = compute_loss(model, x_val, y_val)
            b_val_rmse = compute_rmse(model, x_val, y_val)
            batch_losses_t.append(b_train_loss.detach().numpy())
            batch_losses_v.append(b_val_loss.detach().numpy())
            batch_rmse_v.append(b_val_rmse.detach().numpy())
        train_loss = np.mean(batch_losses_t)
        val_loss = np.mean(batch_losses_v)
        val_rmse = np.mean(batch_rmse_v)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
    ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()
    return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
    if rs is None:
        rs = np.random.RandomState()
    train_ix = np.arange(len(x))
    rs.shuffle(train_ix)
    n_batches = int(np.ceil(len(x) / batch_size))
    batch_indices = []
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_indices.append(
            train_ix[start:end].tolist()
        )
    return batch_indices
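A small sanity check of the batching helper (hypothetical, illustrative only): each epoch, the shuffled batches should partition the dataset exactly once.

_batches = sample_batch_indices(np.zeros(10), np.zeros(10), batch_size=4)
print([len(b) for b in _batches])                # [4, 4, 2]
print(sorted(ix for b in _batches for ix in b))  # [0, 1, ..., 9]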
def compute_train_test_split(x, y, test_size):
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size)
    return (
        torch.from_numpy(x_train),
        torch.from_numpy(x_test),
        torch.from_numpy(y_train),
        torch.from_numpy(y_test),
    )
model = LinearNormal()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, weight_decay=1e-5)
x_train, x_test, y_train, y_test = compute_train_test_split(x, y_sample, test_size=0.2)
train(model, optimizer, x_train, x_test, y_train, y_test, n_epochs=2000, print_every=100);
Epoch 1 | Validation loss = 2.9241 | Validation RMSE = 6.9634
Epoch 100 | Validation loss = 1.8937 | Validation RMSE = 1.8496
Epoch 200 | Validation loss = 1.2464 | Validation RMSE = 1.1442
Epoch 300 | Validation loss = 1.1188 | Validation RMSE = 1.1171
Epoch 400 | Validation loss = 1.1367 | Validation RMSE = 1.0144
Epoch 500 | Validation loss = 1.0793 | Validation RMSE = 1.0365
Epoch 600 | Validation loss = 1.0762 | Validation RMSE = 0.9811
Epoch 700 | Validation loss = 1.0844 | Validation RMSE = 1.0260
Epoch 800 | Validation loss = 1.0596 | Validation RMSE = 0.9475
Epoch 900 | Validation loss = 1.0768 | Validation RMSE = 1.0772
Epoch 1000 | Validation loss = 1.0654 | Validation RMSE = 1.0366
Epoch 1100 | Validation loss = 1.0628 | Validation RMSE = 0.9851
Epoch 1200 | Validation loss = 1.0592 | Validation RMSE = 0.9588
Epoch 1300 | Validation loss = 1.0843 | Validation RMSE = 1.0760
Epoch 1400 | Validation loss = 1.0590 | Validation RMSE = 1.0231
Epoch 1500 | Validation loss = 1.0792 | Validation RMSE = 0.9972
Epoch 1600 | Validation loss = 1.0645 | Validation RMSE = 0.9530
Epoch 1700 | Validation loss = 1.0616 | Validation RMSE = 1.0036
Epoch 1800 | Validation loss = 1.0590 | Validation RMSE = 1.0454
Epoch 1900 | Validation loss = 1.0707 | Validation RMSE = 1.0476
Epoch 2000 | Validation loss = 1.0516 | Validation RMSE = 0.9984
val_rmse = float(compute_rmse(model, x_test, y_test).detach().numpy())
print(f'Validation RMSE = {val_rmse}')
Validation RMSE = 1.0679740982747383
def plot_distributions(model):
    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes = axes.flatten()
    model.eval()
    mean = model.mean.detach()
    sigma = torch.log(1 + torch.exp(model.sigma)).detach()
    for i, param in enumerate(['α', 'β', 'σ']):
        ax = axes[i]
        ax.set_title(f'{param}')
        if param == 'σ':
            # σ lives in softplus space, so map its mean back before plotting.
            m = torch.log(1 + torch.exp(mean[i]))
        else:
            m = mean[i]
        s = sigma[i]
        dist = torch.distributions.Normal(m, s)
        # histplot replaces the deprecated distplot (seaborn >= 0.11).
        sns.histplot(dist.sample(torch.Size((2000,))).numpy(), kde=True, ax=ax)
plot_distributions(model)
model.eval()
alpha, beta, sigma = model.mean
α_hat = float(alpha)
β_hat = float(beta)
σ_hat = float(torch.log(1 + torch.exp(sigma)))
print(f'Actual α = {α_actual:.2f} | Predicted α = {α_hat:.2f}')
print(f'Actual β = {β_actual:.2f} | Predicted β = {β_hat:.2f}')
print(f'Actual σ = {σ_actual:.2f} | Predicted σ = {σ_hat:.2f}')
Actual α = 2.60 | Predicted α = 2.57
Actual β = 3.30 | Predicted β = 3.34
Actual σ = 0.70 | Predicted σ = 0.75
y_hat, _ = model(x_test, sample=False)
def plot_results(x, y, y_sample, y_pred, y_pred_sample):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(x.flatten(), y.flatten(), '-', color=palette[0], linewidth=2, label='Actual line')
    ax.plot(x.flatten(), y_pred.flatten(), '-', color=palette[1], linewidth=2, label='Predicted line')
    ax.scatter(x.flatten(), y_sample.flatten(), color=palette[0], label='Actual samples')
    ax.scatter(x.flatten(), y_pred_sample.flatten(), color=palette[1], label='Predicted samples')
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.legend()
    return ax
plot_results(
    x_test.detach().numpy(),
    α_actual * x_test.detach().numpy() + β_actual,
    y_test.detach().numpy(),
    α_hat * x_test.detach().numpy() + β_hat,
    y_hat.sample().detach().numpy()
);
def sample_parameters(model, n_samples=100):
    model.eval()
    mean = model.mean.detach()
    sigma = torch.log(1 + torch.exp(model.sigma)).detach()
    params = np.zeros((n_samples, 3))
    for i, param in enumerate(['α', 'β', 'σ']):
        if param == 'σ':
            m = torch.log(1 + torch.exp(mean[i]))
        else:
            m = mean[i]
        s = sigma[i]
        dist = torch.distributions.Normal(m, s)
        params[:,i] = dist.sample(torch.Size((n_samples,))).detach().numpy().flatten()
    return params
params = sample_parameters(model, n_samples=1000)
params.shape
(1000, 3)
def make_predictions(params, x):
    # Aleatoric noise: average of the sampled σ values.
    std_aleatoric = np.mean(params[:,2])
    lines = np.zeros((len(x), len(params)))
    for sample_number in range(len(params)):
        lines[:,sample_number] = (params[sample_number, 0] * x + params[sample_number,1]).flatten()
    predictions = np.mean(lines, axis=1)
    # Epistemic uncertainty: spread of the lines induced by parameter samples.
    std_epistemic = np.std(lines, axis=1)
    std = std_epistemic + std_aleatoric
    return predictions, std
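make_predictions separates the two sources of uncertainty: the spread of the sampled lines gives the epistemic standard deviation, while the mean of the sampled σ values gives the aleatoric one. Summing the two standard deviations is a conservative shortcut; assuming the two sources are independent, they would instead combine in quadrature, as in this hypothetical sketch:

def combine_uncertainties(std_epistemic, std_aleatoric):
    # Hypothetical alternative: combine independent uncertainties in quadrature.
    return np.sqrt(std_epistemic**2 + std_aleatoric**2)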
def plot_uncertainty(x, y, y_pred, y_err):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(x.flatten(), y.flatten(), '-', color=palette[0], linewidth=2, label='Actual line')
    ax.plot(x.flatten(), y_pred.flatten(), '-', color=palette[1], linewidth=2, label='Predicted line')
    ax.fill_between(
        x.flatten(),
        y_pred.flatten() + 2 * y_err,
        y_pred.flatten() - 2 * y_err,
        color=palette[1],
        alpha=0.3,
    )
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.legend()
    return ax
min_ = 10000
max_ = min_ + 1000
x_new, y_new, _ = generate_samples(α_actual, β_actual, σ_actual, min_x=min_, max_x=max_)
predictions, std = make_predictions(params, x_new)
print(np.mean(std))
plot_uncertainty(x_new, y_new, predictions, std);
58.40133088282731
min_ = 100000
max_ = min_ + 1000
x_new, y_new, _ = generate_samples(α_actual, β_actual, σ_actual, min_x=min_, max_x=max_)
predictions, std = make_predictions(params, x_new)
print(np.mean(std))
plot_uncertainty(x_new, y_new, predictions, std);
552.5717687488561
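Far from the training interval the epistemic term dominates: each sampled slope tilts its line around the data, so the fan of lines widens roughly linearly, $\mathrm{std}_{\text{epistemic}}(x) \approx \mathrm{std}_q(\alpha)\,|x|$, which is why the mean predictive standard deviation grows about tenfold as $x$ moves from around $10^4$ to around $10^5$.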