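Implementation of the classic linear regression model. The weights are fitted by maximum likelihood estimation (MLE) using PyTorch distributions, under the model
$y \sim \mathcal{N}(\alpha x + \beta, \sigma)$, where $\sigma$ is the standard deviation of the Gaussian noise (the `scale` argument of `torch.distributions.Normal`).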
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
# When launched from inside the ``notebook`` directory, move up one level.
if os.getcwd().endswith('notebook'):
    os.chdir('..')
cwd = os.getcwd()

# Plot styling shared by every figure below.
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()

# Seed NumPy's and PyTorch's global RNGs.
seed = 444
np.random.seed(seed)
torch.manual_seed(seed);
# Ground-truth parameters of the synthetic data: slope, intercept, noise std.
α_actual, β_actual, σ_actual = 2.6, 3.3, 0.7
def generate_samples(α, β, σ, min_x=-1, max_x=1, n_samples=500):
    """Draw noisy observations from the line ``y = α·x + β``.

    Args:
        α: Slope of the line.
        β: Intercept of the line.
        σ: Standard deviation of the additive Gaussian noise.
        min_x: Smallest input value.
        max_x: Largest input value.
        n_samples: Number of evenly spaced inputs.

    Returns:
        ``(x, y, y_sample)``, each a NumPy column of shape ``(n_samples, 1)``:
        the inputs, the noise-free line, and the line with noise sampled
        through ``torch.distributions.Normal`` (so ``torch.manual_seed``
        controls it).
    """
    inputs = np.linspace(min_x, max_x, n_samples).reshape(-1, 1)
    line = α * inputs + β
    noisy = torch.distributions.Normal(torch.from_numpy(line), σ).sample()
    return inputs, line, noisy.detach().numpy()
def plot_line(x, y, y_sample):
    """Draw the noise-free line and the noisy samples on a new 12x6 figure.

    Relies on the module-level ``palette`` for colours.

    Returns:
        The ``(figure, axes)`` pair from ``plt.subplots``.
    """
    fig, axes = plt.subplots(1, 1, figsize=(12, 6))
    xs = x.flatten()
    axes.plot(xs, y.flatten(), '-', color=palette[2], linewidth=3)
    axes.scatter(xs, y_sample.flatten(), color=palette[0], alpha=0.8)
    axes.set(
        xlabel='input x',
        ylabel='output y',
        title=r'$y \sim N(\alpha x + \beta, \sigma)$',
    )
    return fig, axes
# Draw the synthetic dataset from the ground-truth parameters and plot it.
x, y, y_sample = generate_samples(α_actual, β_actual, σ_actual)
f, _ = plot_line(x, y, y_sample);
# Inspect the input shape; generate_samples returns x as a column vector.
x.shape
(500, 1)
class LinearNormal(torch.nn.Module):
    """Linear-Gaussian regression: ``y ~ Normal(α·x + β, softplus(s))``.

    ``α``, ``β`` and the raw scale ``s`` are scalar parameters initialised
    from a standard normal distribution.
    """

    def __init__(self):
        super().__init__()
        # Creation order (α, β, s) fixes both the RNG draws and the order in
        # which ``parameters()`` yields them.
        self.α = torch.nn.Parameter(torch.randn([]))
        self.β = torch.nn.Parameter(torch.randn([]))
        self.s = torch.nn.Parameter(torch.randn([]))

    @property
    def sigma(self):
        """Noise standard deviation; softplus keeps it strictly positive."""
        return F.softplus(self.s)

    def forward(self, x):
        """Return the predictive ``Normal`` distribution for inputs ``x``."""
        return torch.distributions.Normal(loc=self.β + self.α * x, scale=self.sigma)
def compute_loss(model, x, y):
    """Mean negative log-likelihood of targets ``y`` under ``model(x)``.

    Args:
        model: Callable mapping ``x`` to a ``torch.distributions`` object.
        x: Input tensor.
        y: Target tensor, broadcast-compatible with the distribution.

    Returns:
        0-d tensor, differentiable w.r.t. the model parameters.
    """
    predictive = model(x)
    return predictive.log_prob(y).neg().mean()
def compute_rmse(model, x_test, y_test, *, use_mean=False):
    """Root-mean-squared error of the model's predictions on ``(x_test, y_test)``.

    Switches ``model`` to eval mode and leaves it there.

    Args:
        model: Module mapping ``x_test`` to a ``torch.distributions`` object.
        x_test: Input tensor.
        y_test: Target tensor, broadcast-compatible with the prediction.
        use_mean: If False (default), compare a single random draw from the
            predictive distribution with ``y_test``. The result is then
            stochastic and includes the observation noise. If True, use the
            predictive mean, which gives a deterministic point-prediction
            RMSE.

    Returns:
        0-d tensor holding the RMSE (no autograd history).
    """
    model.eval()
    # Pure evaluation metric: no gradients needed.
    with torch.no_grad():
        out_dist = model(x_test)
        pred = out_dist.mean if use_mean else out_dist.sample()
        return torch.sqrt(torch.mean((pred - y_test) ** 2))
def predict(model, x):
    """Evaluate ``model`` on ``x`` in eval mode and unpack its prediction.

    Args:
        model: Module mapping ``x`` to a ``torch.distributions`` object.
        x: Input tensor.

    Returns:
        ``(mean, stddev, distribution)`` of the predictive distribution.
    """
    model.eval()
    dist = model(x)
    mean, std = dist.mean, dist.stddev
    return mean, std, dist
def train_one_step(model, optimizer, x_batch, y_batch):
    """Perform one optimizer update on a single mini-batch.

    Puts ``model`` in train mode, clears stale gradients, back-propagates
    the mean negative log-likelihood from ``compute_loss`` and steps
    ``optimizer``.

    Returns:
        The batch loss tensor (still attached to the autograd graph).
    """
    model.train()
    optimizer.zero_grad()
    nll = compute_loss(model, x_batch, y_batch)
    nll.backward()
    optimizer.step()
    return nll
def _plot_training_overview(train_losses, val_losses):
    """Plot per-epoch train and validation losses on a new 12x6 figure."""
    epochs = range(1, len(train_losses) + 1)
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(epochs, train_losses, label='Train loss')
    ax.plot(epochs, val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()


def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size=64, print_every=10):
    """Fit ``model`` with mini-batch gradient descent and plot the loss curves.

    Each epoch the training rows are shuffled into mini-batches. After every
    optimizer step the model is scored on the full validation set. The
    per-epoch numbers are the means of those per-batch scores.

    Args:
        model: Module returning a ``torch.distributions`` object.
        optimizer: ``torch.optim`` optimizer over ``model``'s parameters.
        x_train, y_train: Training tensors, indexed by row.
        x_val, y_val: Validation tensors.
        n_epochs: Number of passes over the training data.
        batch_size: Maximum rows per mini-batch.
        print_every: Log validation metrics on epoch 1 and every
            ``print_every`` epochs thereafter.

    Returns:
        ``(train_losses, val_losses)``: one mean loss per epoch each.
    """
    train_losses, val_losses = [], []
    for epoch in range(n_epochs):
        batch_indices = sample_batch_indices(x_train, y_train, batch_size)
        batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
        for batch_ix in batch_indices:
            b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])
            model.eval()
            # Validation metrics are only logged, never back-propagated, so
            # don't build an autograd graph for them on every batch.
            with torch.no_grad():
                b_val_loss = compute_loss(model, x_val, y_val)
                b_val_rmse = compute_rmse(model, x_val, y_val)
            batch_losses_t.append(b_train_loss.detach().numpy())
            batch_losses_v.append(b_val_loss.detach().numpy())
            batch_rmse_v.append(b_val_rmse.detach().numpy())
        train_loss = np.mean(batch_losses_t)
        val_loss = np.mean(batch_losses_v)
        val_rmse = np.mean(batch_rmse_v)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
    _plot_training_overview(train_losses, val_losses)
    return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
    """Shuffle the row indices of ``x`` and split them into mini-batches.

    Args:
        x: Sized collection of training inputs; only ``len(x)`` is used.
        y: Training targets. Unused; kept for signature compatibility.
        batch_size: Maximum number of indices per batch.
        rs: Optional ``np.random.RandomState`` used for shuffling. When
            ``None``, NumPy's global RNG is used, so ``np.random.seed``
            makes the batching reproducible.

    Returns:
        ``ceil(len(x) / batch_size)`` lists of ints that together form a
        permutation of ``range(len(x))``. Every batch holds ``batch_size``
        indices except possibly the last one.
    """
    shuffle = np.random.shuffle if rs is None else rs.shuffle
    train_ix = np.arange(len(x))
    shuffle(train_ix)
    n_batches = int(np.ceil(len(x) / batch_size))
    # Batch i is the disjoint slice [i * batch_size, (i + 1) * batch_size).
    return [
        train_ix[i * batch_size:(i + 1) * batch_size].tolist()
        for i in range(n_batches)
    ]
def compute_train_test_split(x, y, test_size):
    """Split ``x``/``y`` with scikit-learn and wrap each part as a torch tensor.

    Args:
        x: NumPy array of inputs.
        y: NumPy array of targets, aligned with ``x``.
        test_size: Passed straight to ``train_test_split``.

    Returns:
        ``(x_train, x_test, y_train, y_test)`` as tensors created with
        ``torch.from_numpy``.
    """
    parts = train_test_split(x, y, test_size=test_size)
    return tuple(torch.from_numpy(part) for part in parts)
# Fresh model with randomly initialised parameters, optimised with plain SGD.
model = LinearNormal()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-3)
# Hold out 20% of the noisy samples; the held-out split is also what train()
# uses as its validation set.
x_train, x_test, y_train, y_test = compute_train_test_split(x, y_sample, test_size=0.2)
train(model, optimizer, x_train, x_test, y_train, y_test, n_epochs=2000, print_every=200);
Epoch 1 | Validation loss = 2.9909 | Validation RMSE = 3.8504 Epoch 200 | Validation loss = 2.0633 | Validation RMSE = 2.9213 Epoch 400 | Validation loss = 1.5048 | Validation RMSE = 1.6410 Epoch 600 | Validation loss = 1.1390 | Validation RMSE = 1.0149 Epoch 800 | Validation loss = 1.1402 | Validation RMSE = 1.0237 Epoch 1000 | Validation loss = 1.1451 | Validation RMSE = 0.9709 Epoch 1200 | Validation loss = 1.1471 | Validation RMSE = 1.0116 Epoch 1400 | Validation loss = 1.1419 | Validation RMSE = 1.0203 Epoch 1600 | Validation loss = 1.1433 | Validation RMSE = 1.0337 Epoch 1800 | Validation loss = 1.1448 | Validation RMSE = 1.0112 Epoch 2000 | Validation loss = 1.1406 | Validation RMSE = 1.0301
# Predictive mean and full distribution on the held-out set (stddev unused).
y_pred, _, y_dist = predict(model, x_test)
# One random draw per held-out point, used for the R² score below.
y_pred_sample = y_dist.sample()
val_rmse = compute_rmse(model, x_test, y_test).item()
print(f'Validation RMSE = {val_rmse}')
Validation RMSE = 0.9420654411386103
# Coefficient of determination on the held-out set.
# NOTE(review): scored against y_pred_sample (a random draw from the
# predictive distribution) rather than the predictive mean y_pred, so the
# value includes the observation noise and is itself random.
val_r2 = r2_score(y_test.numpy(), y_pred_sample.numpy())
print(f'Validation R squared = {val_r2}')
Validation R squared = 0.6884557510131627
def plot_results(x, y, y_sample, y_pred, std):
    """Plot validation predictions with a ±2σ band against the true line.

    All series are sorted by ``x`` so that lines are drawn left to right.
    Colours come from the module-level ``palette``.

    Args:
        x: Validation inputs (NumPy array, flattened internally).
        y: Noise-free targets aligned with ``x``.
        y_sample: Observed (noisy) targets aligned with ``x``.
        y_pred: Predicted means aligned with ``x``.
        std: Predictive standard deviation(s). With one value per point the
            band width varies point-wise; otherwise the first element is
            used as a constant width (the homoscedastic case).

    Returns:
        The ``(figure, axes)`` pair.
    """
    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    x_flat = x.flatten()
    x_arg_sort = x_flat.argsort()
    xx = x_flat[x_arg_sort]
    v = y_pred.flatten()[x_arg_sort]
    std_flat = std.flatten()
    std_val = std_flat[x_arg_sort] if std_flat.size == x_flat.size else std_flat[0]
    v_err_min = v - 2 * std_val
    v_err_max = v + 2 * std_val
    ax.fill_between(
        xx,
        v_err_min,
        v_err_max,
        color=palette[1],
        alpha=0.2,
        # Raw string: '\s' is not a valid escape sequence in a normal literal.
        label=r'2$\sigma$ error',
    )
    ax.plot(xx, y.flatten()[x_arg_sort], '-', color=palette[0], linewidth=2, label='Actual')
    ax.scatter(xx, y_sample.flatten()[x_arg_sort], color=palette[0])
    ax.plot(xx, v, '-', color=palette[1], linewidth=2, label='Predicted')
    ax.set_title('Predictions on the validation set')
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.legend()
    return f, ax
# Visualise predictions on the held-out set against the true line.
f, ax = plot_results(
    x_test.detach().numpy(),
    α_actual * x_test.detach().numpy() + β_actual,
    y_test.detach().numpy(),
    y_pred.detach().numpy(),
    y_dist.stddev.detach().numpy(),
)
# Compare the fitted parameters with those used to generate the data.
# ``.item()`` extracts a single-element tensor as a Python float. Calling
# ``float()`` on a shape-(1,) NumPy array (the old σ line) is deprecated
# since NumPy 1.25.
α_hat = model.α.item()
β_hat = model.β.item()
σ_hat = y_dist.stddev.flatten()[0].item()
print(f'Actual α = {α_actual:.1f} | Predicted α = {α_hat:.1f}')
print(f'Actual β = {β_actual:.1f} | Predicted β = {β_hat:.1f}')
print(f'Actual σ = {σ_actual:.1f} | Predicted σ = {σ_hat:.1f}')
Actual α = 2.6 | Predicted α = 2.6 Actual β = 3.3 | Predicted β = 3.3 Actual σ = 0.7 | Predicted σ = 0.7