Implementation of a linear regression model. Weights are estimated with Markov Chain Monte Carlo (MCMC) sampling using PyTorch distributions & Pyro.
$y \sim \mathcal{N}(\alpha x + \beta, \sigma)$
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
# If the script is launched from the notebook/ directory, move up to the
# project root so relative paths resolve consistently.
cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()

# Global plotting style: colorblind-friendly palette, slightly larger fonts.
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
def linear_model(x_obs, y_obs):
    """Bayesian linear regression model: y ~ Normal(α·x + β, σ).

    Priors: α, β ~ Normal(0, 10); σ ~ Uniform(0, 10).
    Observations are conditioned on inside a plate over the data points.
    """
    slope = pyro.sample('α', dist.Normal(0., 10.))
    intercept = pyro.sample('β', dist.Normal(0., 10.))
    noise_scale = pyro.sample('σ', dist.Uniform(0., 10.))
    mean = slope * x_obs + intercept
    with pyro.plate('data', len(x_obs)):
        pyro.sample('obs', dist.Normal(mean, noise_scale), obs=y_obs)
# Ground-truth parameters used to generate the synthetic dataset; the MCMC
# posterior means below should recover values close to these.
α_actual = 2.6
β_actual = 3.3
σ_actual = 0.7
def generate_samples(α, β, σ, min_x=-1, max_x=1, n_samples=500):
    """Generate synthetic data from the line y = α·x + β with Gaussian noise.

    Args:
        α: slope of the line.
        β: intercept of the line.
        σ: standard deviation of the observation noise.
        min_x, max_x: inclusive range over which inputs are evenly spaced.
        n_samples: number of points to generate.

    Returns:
        Tuple (x, y, y_sample) of numpy arrays of shape (n_samples, 1):
        inputs, noise-free outputs, and noisy observations.
    """
    x = np.linspace(min_x, max_x, n_samples)[:, np.newaxis]
    y = α * x + β
    # Use a dedicated name for the noise distribution: the original code
    # shadowed the module-level `pyro.distributions as dist` import.
    noise = torch.distributions.Normal(torch.from_numpy(y), σ)
    return x, y, noise.sample().detach().numpy()
def plot_line(x, y, y_sample):
    """Plot the noise-free line together with its noisy samples; return the axes."""
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    line_color = palette[0]
    ax.plot(x.flatten(), y.flatten(), '-', color=line_color, linewidth=3)
    ax.scatter(x.flatten(), y_sample.flatten(), color=line_color, alpha=0.8)
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    return ax
# Draw the synthetic dataset from the ground-truth parameters and plot it.
x, y, y_sample = generate_samples(α_actual, β_actual, σ_actual)
plot_line(x, y, y_sample);
def compute_train_test_split(x, y, test_size, random_state=None):
    """Split (x, y) into train/test sets and convert them to torch tensors.

    Args:
        x: input numpy array.
        y: target numpy array.
        test_size: fraction (or absolute count) of samples held out for testing.
        random_state: optional seed forwarded to sklearn's `train_test_split`
            for a reproducible split. Default None preserves the original
            non-deterministic behaviour.

    Returns:
        Tuple (x_train, x_test, y_train, y_test) of torch tensors.
    """
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state
    )
    return (
        torch.from_numpy(x_train),
        torch.from_numpy(x_test),
        torch.from_numpy(y_train),
        torch.from_numpy(y_test),
    )
# Hold out 20% of the noisy observations for validation.
x_train, x_test, y_train, y_test = compute_train_test_split(x, y_sample, test_size=0.2)
# Utility function to print latent sites' quantile information.
def summary(samples):
    """Per-site posterior summary: mean, std, and selected percentiles.

    Args:
        samples: mapping of site name to array of posterior draws.

    Returns:
        Dict mapping each site name to a one-row DataFrame with columns
        mean, std, 5%, 25%, 50%, 75%, 95%.
    """
    percentiles = [0.05, 0.25, 0.5, 0.75, 0.95]
    columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        name: pd.DataFrame(values).describe(percentiles=percentiles).T[columns]
        for name, values in samples.items()
    }
# Run NUTS (No-U-Turn Sampler) to draw posterior samples of α, β and σ,
# with 2000 warmup steps followed by 2000 retained samples.
nuts_kernel = NUTS(linear_model)
mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=2000)
mcmc.run(x_train, y_train)
Sample: 100%|██████████| 4000/4000 [01:25, 46.89it/s, step size=5.40e-01, acc. prob=0.915]
# Pull the posterior samples back to numpy and print a per-site summary table.
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
res = summary(hmc_samples)
for site, values in res.items():
    print(site)
    print(values, "\n")
α:  mean 2.577201  std 0.002842  5% 2.572460  25% 2.575314  50% 2.577224  75% 2.579093  95% 2.581918
β:  mean 3.232284  std 0.001714  5% 3.229532  25% 3.231135  50% 3.232261  75% 3.233455  95% 3.235203
σ:  mean 0.693435  std 0.001267  5% 0.691422  25% 0.692532  50% 0.693427  75% 0.694282  95% 0.695575
def plot_parameter_distributions(hmc_samples, res):
    """Plot the marginal posterior distribution of each latent site.

    Args:
        hmc_samples: mapping of site name to numpy array of posterior draws.
        res: summary dict (as returned by `summary`); only its keys/order
            are used to lay out the subplots.
    """
    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes = axes.flatten()
    for i, site in enumerate(res):
        ax = axes[i]
        ax.set_title(f'{site}')
        # `distplot` is deprecated (emits FutureWarning); `histplot` with a
        # KDE overlay and density normalisation is the axes-level replacement.
        sns.histplot(hmc_samples[site], kde=True, stat='density', ax=ax)
# Visualise the marginal posteriors of α, β and σ.
plot_parameter_distributions(hmc_samples, res)
/Users/srom/workspace/distributions/env/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /Users/srom/workspace/distributions/env/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /Users/srom/workspace/distributions/env/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
def compute_rmse(y_actual, y_prediction):
    """Root-mean-square error between two tensors of matching shape."""
    squared_error = (y_actual - y_prediction) ** 2
    return squared_error.mean().sqrt()
# Point estimates: posterior mean of each latent site, compared to ground truth.
α_hat = res['α']['mean'].iloc[0]
β_hat = res['β']['mean'].iloc[0]
σ_hat = res['σ']['mean'].iloc[0]
print(f'Actual α = {α_actual:.2f} | Predicted α = {α_hat:.2f}')
print(f'Actual β = {β_actual:.2f} | Predicted β = {β_hat:.2f}')
print(f'Actual σ = {σ_actual:.2f} | Predicted σ = {σ_hat:.2f}')
Actual α = 2.60 | Predicted α = 2.58 Actual β = 3.30 | Predicted β = 3.23 Actual σ = 0.70 | Predicted σ = 0.69
def predict(a, b, s, x):
    """Draw one sample per input from Normal(a·x + b, s) — a noisy prediction."""
    mean = a * x + b
    return torch.distributions.Normal(mean, s).sample()
# Sample predictions on the held-out inputs and report the validation RMSE.
y_hat = predict(α_hat, β_hat, σ_hat, x_test)
val_rmse = float(compute_rmse(y_test, y_hat).detach().numpy())
print(f'Validation RMSE = {val_rmse}')
Validation RMSE = 0.8888827354133352
def plot_results(x, y, y_sample, y_pred, y_pred_sample):
    """Compare the actual and predicted lines along with their samples."""
    actual_color, predicted_color = palette[0], palette[1]
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    xs = x.flatten()
    ax.plot(xs, y.flatten(), '-', color=actual_color, linewidth=2, label='Actual line')
    ax.plot(xs, y_pred.flatten(), '-', color=predicted_color, linewidth=2, label='Predicted line')
    ax.scatter(xs, y_sample.flatten(), color=actual_color, label='Actual samples')
    ax.scatter(xs, y_pred_sample.flatten(), color=predicted_color, label='Predicted samples')
    ax.set_xlabel('input x')
    ax.set_ylabel('output y')
    ax.legend()
    return ax
# Plot actual vs predicted lines (and their noisy samples) on the held-out inputs.
plot_results(
    x_test.detach().numpy(),
    α_actual * x_test.detach().numpy() + β_actual,
    y_test.detach().numpy(),
    α_hat * x_test.detach().numpy() + β_hat,
    y_hat.detach().numpy()
);