#!/usr/bin/env python
# coding: utf-8

# # Linear Regression - MLE
# 
# An implementation of the classic linear regression model. The weights are fitted with Maximum Likelihood Estimation (MLE), i.e. by minimizing the negative log-likelihood with gradient descent, using [PyTorch distributions](https://pytorch.org/docs/stable/distributions.html).
# 
# $y \sim \mathcal{N}(\alpha x + \beta, \sigma)$
# 
# where $\sigma$ is the standard deviation of the observation noise.

# ## Setting up the environment

# In[1]:


import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

# run from the repository root so relative paths resolve
cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()


# In[2]:


sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()


# In[3]:


seed = 444
np.random.seed(seed);
torch.manual_seed(seed);


# ## Generate dataset

# In[4]:


α_actual = 2.6
β_actual = 3.3
σ_actual = 0.7


# In[5]:


def generate_samples(α, β, σ, min_x=-1, max_x=1, n_samples=500):
    x = np.linspace(min_x, max_x, n_samples)[:, np.newaxis]
    y = α * x + β
    # draw noisy observations around the true line
    dist = torch.distributions.Normal(torch.from_numpy(y), σ)
    y_sample = dist.sample()
    return x, y, y_sample.detach().numpy()


# In[6]:


def plot_line(x, y, y_sample):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(x.flatten(), y.flatten(), '-', color=palette[2], linewidth=3)
    ax.scatter(x.flatten(), y_sample.flatten(), color=palette[0], alpha=0.8)
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.set_title(r'$y \sim N(\alpha x + \beta, \sigma)$')
    return f, ax


# In[7]:


x, y, y_sample = generate_samples(α_actual, β_actual, σ_actual)


# In[8]:


f, _ = plot_line(x, y, y_sample);


# In[9]:


x.shape


# ## Define model

# In[10]:


class LinearNormal(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.α = torch.nn.Parameter(torch.randn(()))
        self.β = torch.nn.Parameter(torch.randn(()))
        self.s = torch.nn.Parameter(torch.randn(()))

    @property
    def sigma(self):
        return F.softplus(self.s)  # ensure σ > 0

    def forward(self, x):
        m = self.α * x + self.β
        σ = self.sigma
        return torch.distributions.Normal(m, σ)


# In[11]:


def compute_loss(model, x, y):
    out_dist = model(x)
    neg_log_likelihood = -out_dist.log_prob(y)
    return torch.mean(neg_log_likelihood)


def compute_rmse(model, x_test, y_test):
    model.eval()
    # the prediction is a draw from the predictive distribution, so it includes the noise σ
    pred = model(x_test).sample()
    return torch.sqrt(torch.mean((pred - y_test)**2))


def predict(model, x):
    model.eval()
    out_dist = model(x)
    return out_dist.mean, out_dist.stddev, out_dist
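# A note on the objective (a standard identity, stated here for reference in the model's own
# symbols): for a single observation, the negative log-likelihood of the Gaussian model
# expands to
# 
# $-\log p(y \mid x) = \tfrac{1}{2}\log(2\pi\sigma^2) + \frac{(y - \alpha x - \beta)^2}{2\sigma^2}$
# 
# so for a fixed $\sigma$, minimizing the mean NLL in `compute_loss` is equivalent to ordinary
# least squares on $\alpha$ and $\beta$, while the $\log\sigma$ term lets gradient descent fit
# the noise scale itself.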
# ## Fit linear model

# In[12]:


def train_one_step(model, optimizer, x_batch, y_batch):
    model.train()

    optimizer.zero_grad()
    loss = compute_loss(model, x_batch, y_batch)
    loss.backward()
    optimizer.step()

    return loss


# In[13]:


def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size=64, print_every=10):
    train_losses, val_losses = [], []
    for epoch in range(n_epochs):
        batch_indices = sample_batch_indices(x_train, y_train, batch_size)

        batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
        for batch_ix in batch_indices:
            b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])

            model.eval()
            b_val_loss = compute_loss(model, x_val, y_val)
            b_val_rmse = compute_rmse(model, x_val, y_val)

            batch_losses_t.append(b_train_loss.detach().numpy())
            batch_losses_v.append(b_val_loss.detach().numpy())
            batch_rmse_v.append(b_val_rmse.detach().numpy())

        train_loss = np.mean(batch_losses_t)
        val_loss = np.mean(batch_losses_v)
        val_rmse = np.mean(batch_rmse_v)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')

    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
    ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()

    return train_losses, val_losses


def sample_batch_indices(x, y, batch_size, rs=None):
    if rs is None:
        rs = np.random.RandomState()

    # shuffle the indices, then cut them into batch-sized chunks
    train_ix = np.arange(len(x))
    rs.shuffle(train_ix)

    n_batches = int(np.ceil(len(x) / batch_size))

    batch_indices = []
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_indices.append(train_ix[start:end].tolist())

    return batch_indices


# In[14]:


def compute_train_test_split(x, y, test_size):
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size)
    return (
        torch.from_numpy(x_train),
        torch.from_numpy(x_test),
        torch.from_numpy(y_train),
        torch.from_numpy(y_test),
    )


# In[15]:


model = LinearNormal()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-3)


# In[16]:


x_train, x_test, y_train, y_test = compute_train_test_split(x, y_sample, test_size=0.2)


# In[17]:


train(model, optimizer, x_train, x_test, y_train, y_test, n_epochs=2000, print_every=200);


# ## Validation

# In[18]:


y_pred, _, y_dist = predict(model, x_test)
y_pred_sample = y_dist.sample()


# In[19]:


val_rmse = float(compute_rmse(model, x_test, y_test).detach().numpy())
print(f'Validation RMSE = {val_rmse}')


# In[20]:


# R² is computed on a sample from the predictive distribution, not on its mean
val_r2 = r2_score(y_test.detach().numpy(), y_pred_sample.detach().numpy())
print(f'Validation R squared = {val_r2}')


# ## Plot results

# In[21]:


def plot_results(x, y, y_sample, y_pred, std):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))

    x_arg_sort = x.flatten().argsort()
    xx = x.flatten()[x_arg_sort]
    v = y_pred.flatten()[x_arg_sort]

    # σ is constant across inputs, so one value suffices for the error band
    std_val = std.flatten()[0]
    v_err_min = v - 2 * std_val
    v_err_max = v + 2 * std_val

    ax.fill_between(
        xx,
        v_err_min,
        v_err_max,
        color=palette[1],
        alpha=0.2,
        label=r'2$\sigma$ error',
    )
    ax.plot(xx, y.flatten()[x_arg_sort], '-', color=palette[0], linewidth=2, label='Actual')
    ax.scatter(xx, y_sample.flatten()[x_arg_sort], color=palette[0])
    ax.plot(xx, v, '-', color=palette[1], linewidth=2, label='Predicted')

    ax.set_title('Predictions on the validation set')
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.legend()

    return f, ax


# In[22]:


f, ax = plot_results(
    x_test.detach().numpy(),
    α_actual * x_test.detach().numpy() + β_actual,
    y_test.detach().numpy(),
    y_pred.detach().numpy(),
    y_dist.stddev.detach().numpy(),
)


# In[23]:


α_hat = float(model.α.detach().numpy())
β_hat = float(model.β.detach().numpy())
σ_hat = float(y_dist.stddev.detach().numpy()[0])

print(f'Actual α = {α_actual:.1f} | Predicted α = {α_hat:.1f}')
print(f'Actual β = {β_actual:.1f} | Predicted β = {β_hat:.1f}')
print(f'Actual σ = {σ_actual:.1f} | Predicted σ = {σ_hat:.1f}')
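# ## Sanity check: closed-form least squares
# 
# For a Gaussian likelihood, the MLE of $\alpha$ and $\beta$ coincides with the ordinary
# least-squares solution, so the gradient-descent fit can be cross-checked against the
# closed form. A minimal sketch (assuming `x_train` and `y_train` from above are still in
# scope; `np.polyfit` with `deg=1` returns the least-squares slope and intercept):

# In[ ]:


α_ols, β_ols = np.polyfit(
    x_train.detach().numpy().flatten(),  # inputs as a 1-D array
    y_train.detach().numpy().flatten(),  # noisy targets as a 1-D array
    deg=1,
)
print(f'OLS α = {α_ols:.3f} | OLS β = {β_ols:.3f}')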