Implementation of Stochastic Variational Inference (SVI) in PyTorch, for the purpose of uncertainty quantification.
We'll consider the OLS Regression Challenge, which aims to predict cancer mortality rates for US counties.
import os
from os.path import join
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from scipy.stats import binned_statistic
cwd = os.getcwd()
if cwd.endswith('notebook'):
os.chdir('..')
cwd = os.getcwd()
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
seed = 444
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)
df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.372500 | 4.333096 |
2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.922190 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.000000 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.665830 | 0.492135 | 54.027460 | 6.796657 |
5 rows × 34 columns
df.describe()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3.047000e+03 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | ... | 2438.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 |
mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 1.026374e+05 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 3.290592e+05 | 6.409087 | 529.628366 | 45.304480 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.517710 | 6.572814 | 1.985816 |
min | 6.000000 | 3.000000 | 59.700000 | 201.300000 | 22640.000000 | 8.270000e+02 | 3.200000 | 0.000000 | 22.300000 | 22.400000 | ... | 15.700000 | 13.500000 | 11.200000 | 2.600000 | 10.199155 | 0.000000 | 0.000000 | 0.000000 | 22.992490 | 0.000000 |
25% | 76.000000 | 28.000000 | 161.200000 | 420.300000 | 38882.500000 | 1.168400e+04 | 12.150000 | 0.000000 | 37.700000 | 36.350000 | ... | 41.000000 | 34.500000 | 30.900000 | 14.850000 | 77.296180 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
50% | 171.000000 | 61.000000 | 178.100000 | 453.549422 | 45207.000000 | 2.664300e+04 | 15.900000 | 0.000000 | 41.000000 | 39.600000 | ... | 48.700000 | 41.100000 | 36.300000 | 18.800000 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
75% | 518.000000 | 149.000000 | 195.200000 | 480.850000 | 52492.000000 | 6.867100e+04 | 20.400000 | 83.650776 | 44.000000 | 42.500000 | ... | 55.600000 | 47.700000 | 41.550000 | 23.100000 | 95.451693 | 10.509732 | 1.221037 | 2.177960 | 55.395132 | 6.493677 |
max | 38150.000000 | 14010.000000 | 362.800000 | 1206.900000 | 125635.000000 | 1.017029e+07 | 47.400000 | 9762.308998 | 624.000000 | 64.700000 | ... | 78.900000 | 70.700000 | 65.100000 | 46.600000 | 100.000000 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |
8 rows × 32 columns
_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');
target = 'TARGET_deathRate'
features = [
col for col in df.columns
if col not in [
target,
'Geography', # Label describing the county - each row has a different one
'binnedInc', # Redundant with median income?
'PctSomeCol18_24', # contains null values - ignoring for now
'PctEmployed16_Over', # contains null values - ignoring for now
'PctPrivateCoverageAlone', # contains null values - ignoring for now
]
]
print(len(features), 'features')
28 features
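As a quick sanity check (not part of the original pipeline), we can confirm the missing values behind the last three exclusions directly:
# Count missing values in the columns excluded above because of nulls
null_cols = ['PctSomeCol18_24', 'PctEmployed16_Over', 'PctPrivateCoverageAlone']
print(df[null_cols].isnull().sum())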
x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)
(3047, 28) (3047, 1)
class VariationalModel(torch.nn.Module):
def __init__(
self,
n_inputs,
n_hidden,
priors,
x_scaler,
y_scaler,
n_shared_layers=1,
):
super().__init__()
self.priors = priors
self.x_scaler = x_scaler
self.y_scaler = y_scaler
self.jitter = 1e-6
        # Shared layers
        # Note: a plain Python list is not registered by nn.Module, so these
        # layers are left out of model.parameters() (and of the optimizer);
        # wrap them in torch.nn.ModuleList if they should be trained.
        self.shared_layers = [
            torch.nn.Linear(n_inputs, n_hidden)
        ]
if n_shared_layers > 1:
for _ in range(n_shared_layers - 1):
self.shared_layers.append(
torch.nn.Linear(n_hidden, n_hidden)
)
# Parameters for the mean
self.mean_loc_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.mean_loc = torch.nn.Linear(n_hidden, 1)
self.mean_scale_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.mean_scale = torch.nn.Linear(n_hidden, 1)
# Parameters for the standard deviation
self.std_loc_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.std_loc = torch.nn.Linear(n_hidden, 1)
self.std_scale_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.std_scale = torch.nn.Linear(n_hidden, 1)
def encoder(self, x):
shared = self.x_scaler(x)
for layer in self.shared_layers:
shared = layer(shared)
shared = F.relu(shared)
#shared = self.dropout(shared)
# Parametrization of the mean (normal)
mean_loc = self.mean_loc_hidden(shared)
mean_loc = F.relu(mean_loc)
mean_loc = self.mean_loc(mean_loc)
mean_scale = self.mean_scale_hidden(shared)
mean_scale = F.relu(mean_scale)
mean_scale = F.softplus(self.mean_scale(mean_scale)) + self.jitter
# Parametrization of the standard deviation (log normal)
std_loc = self.std_loc_hidden(shared)
std_loc = F.relu(std_loc)
std_loc = self.std_loc(std_loc)
std_scale = self.std_scale_hidden(shared)
std_scale = F.relu(std_scale)
std_scale = F.softplus(self.std_scale(std_scale)) + self.jitter
return (
torch.distributions.Normal(mean_loc, mean_scale),
torch.distributions.LogNormal(std_loc, std_scale)
)
def decoder(self, mean, std):
return torch.distributions.Normal(mean, std)
def forward(self, x, sample_shape=torch.Size([])):
mean_dist, std_dist = self.encoder(x)
mean = mean_dist.rsample(sample_shape)
std = std_dist.rsample(sample_shape) + self.jitter
y_hat = self.decoder(mean, std)
kl_divergence = None
if self.priors is not None:
kl_divergence1 = torch.distributions.kl.kl_divergence(mean_dist, self.priors[0])
kl_divergence2 = torch.distributions.kl.kl_divergence(std_dist, self.priors[1])
kl_divergence = kl_divergence1 + kl_divergence2
return y_hat, kl_divergence
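To summarise the model defined above: the encoder maps each (scaled) input to a Normal distribution over the predictive mean and a LogNormal distribution over the predictive standard deviation, and the decoder turns one sample of each into a Gaussian over the scaled target (notation follows the layer names above):
$$
\mu \mid x \sim \mathcal{N}\big(\mathrm{mean\_loc}(x),\ \mathrm{mean\_scale}(x)\big), \qquad
\sigma \mid x \sim \mathrm{LogNormal}\big(\mathrm{std\_loc}(x),\ \mathrm{std\_scale}(x)\big), \qquad
y \mid \mu, \sigma \sim \mathcal{N}(\mu, \sigma)
$$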
def compute_loss(model, x, y, kl_reg=0.1):
"""
ELBO loss - https://en.wikipedia.org/wiki/Evidence_lower_bound
"""
y_scaled = model.y_scaler(y)
y_hat, kl_divergence = model(x)
neg_log_likelihood = -y_hat.log_prob(y_scaled)
if kl_divergence is not None:
neg_log_likelihood += kl_reg * kl_divergence
return torch.mean(neg_log_likelihood)
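compute_loss is a single-sample Monte Carlo estimate of the negative ELBO, with the KL term down-weighted by kl_reg (so, strictly speaking, a β-weighted variant of the ELBO):
$$
\mathcal{L}(x, y) \approx -\log \mathcal{N}\big(y_{\mathrm{scaled}} \mid \mu, \sigma\big)
+ \beta \Big[\mathrm{KL}\big(q(\mu \mid x) \,\|\, p(\mu)\big) + \mathrm{KL}\big(q(\sigma \mid x) \,\|\, p(\sigma)\big)\Big]
$$
where $\mu$ and $\sigma$ are single reparametrized samples from the encoder distributions and $\beta$ is kl_reg.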
def compute_rmse(model, x_test, y_test):
model.eval()
y_hat, _ = model(x_test)
pred = model.y_scaler.inverse_transform(y_hat.mean)
return torch.sqrt(torch.mean((pred - y_test)**2))
def train_one_step(model, optimizer, x_batch, y_batch):
model.train()
optimizer.zero_grad()
loss = compute_loss(model, x_batch, y_batch)
loss.backward()
optimizer.step()
return loss
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
train_losses, val_losses = [], []
for epoch in range(n_epochs):
batch_indices = sample_batch_indices(x_train, y_train, batch_size)
batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
for batch_ix in batch_indices:
b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])
model.eval()
b_val_loss = compute_loss(model, x_val, y_val)
b_val_rmse = compute_rmse(model, x_val, y_val)
batch_losses_t.append(b_train_loss.detach().numpy())
batch_losses_v.append(b_val_loss.detach().numpy())
batch_rmse_v.append(b_val_rmse.detach().numpy())
if scheduler is not None:
scheduler.step()
train_loss = np.mean(batch_losses_t)
val_loss = np.mean(batch_losses_v)
val_rmse = np.mean(batch_rmse_v)
train_losses.append(train_loss)
val_losses.append(val_loss)
if epoch == 0 or (epoch + 1) % print_every == 0:
print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Overview')
ax.legend()
return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
if rs is None:
rs = np.random.RandomState()
train_ix = np.arange(len(x))
rs.shuffle(train_ix)
n_batches = int(np.ceil(len(x) / batch_size))
batch_indices = []
for i in range(n_batches):
start = i * batch_size
end = start + batch_size
batch_indices.append(
train_ix[start:end].tolist()
)
return batch_indices
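A small usage example with toy shapes (not part of the original flow), just to illustrate the batching behaviour: every index appears exactly once, and the last batch may be smaller.
ix = sample_batch_indices(np.zeros((7, 2)), np.zeros((7, 1)), batch_size=3, rs=np.random.RandomState(0))
print([len(b) for b in ix])  # [3, 3, 1]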
class StandardScaler(object):
"""
Standardize data by removing the mean and scaling to unit variance.
"""
def __init__(self):
self.mean = None
self.scale = None
def fit(self, sample):
self.mean = sample.mean(0, keepdim=True)
self.scale = sample.std(0, unbiased=False, keepdim=True)
return self
def __call__(self, sample):
return self.transform(sample)
def transform(self, sample):
return (sample - self.mean) / self.scale
def inverse_transform(self, sample):
return sample * self.scale + self.mean
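A quick round-trip check of the scaler on toy data (again, only a sketch, not part of the original flow):
# Fit on toy data, transform, then invert; the original values should be recovered
sample = torch.randn(5, 3) * 10 + 100
scaler = StandardScaler().fit(sample)
scaled = scaler(sample)                      # roughly zero mean, unit variance per column
restored = scaler.inverse_transform(scaled)  # back to the original scale
print(torch.allclose(restored, sample))      # True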
def compute_train_test_split(x, y, test_size):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size)
return (
torch.from_numpy(x_train),
torch.from_numpy(x_test),
torch.from_numpy(y_train),
torch.from_numpy(y_test),
)
x_train, x_val, y_train, y_val = compute_train_test_split(x, y, test_size=0.2)
print(x_train.shape, y_train.shape)
print(x_val.shape, y_val.shape)
torch.Size([2437, 28]) torch.Size([2437, 1])
torch.Size([610, 28]) torch.Size([610, 1])
x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))
def lr_steep_schedule(epoch):
threshold = 5
if epoch < threshold:
return 1.0
multiplier = 2 / 10
step_change = (epoch - threshold) // 300
return multiplier * 0.5 ** (step_change)
def lr_simple_schedule(epoch):
return 0.5 ** (epoch // 200)
def make_priors(y_train, y_scaler):
std = torch.Tensor((0.5,))
scale = torch.Tensor((0.2,))
std_loc = torch.log(std**2 / torch.sqrt(std**2 + scale**2))
std_scale = torch.sqrt(torch.log(1 + scale**2 / std**2))
return [
torch.distributions.Normal(0, 1),
torch.distributions.LogNormal(std_loc, std_scale),
]
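make_priors converts a desired mean (0.5) and spread (0.2) for the noise standard deviation into the loc/scale parameters of a LogNormal (note that the y_train and y_scaler arguments are currently unused). A quick check that the resulting prior has roughly those moments:
_, std_prior = make_priors(y_train, y_scaler)
print(std_prior.mean, std_prior.stddev)  # ~0.5 and ~0.2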
%%time
learning_rate = 1e-3
momentum = 0.9
weight_decay = 1e-4
n_epochs = 400
batch_size = 64
print_every = 50
n_hidden = 100
n_shared_layers = 1
priors = make_priors(y_train, y_scaler)
model = VariationalModel(
n_inputs=x.shape[1],
n_hidden=n_hidden,
priors=priors,
n_shared_layers=n_shared_layers,
x_scaler=x_scaler,
y_scaler=y_scaler,
)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{pytorch_total_params:,} trainable parameters')
print()
optimizer = torch.optim.SGD(
model.parameters(),
lr=learning_rate,
momentum=momentum,
nesterov=True,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lr_steep_schedule,
)
train_losses, val_losses = train(
model,
optimizer,
x_train,
x_val,
y_train,
y_val,
n_epochs=n_epochs,
batch_size=batch_size,
scheduler=scheduler,
print_every=print_every,
)
40,804 trainable parameters

Epoch 1 | Validation loss = 2.1849 | Validation RMSE = 31.4644
Epoch 50 | Validation loss = 1.3104 | Validation RMSE = 20.1894
Epoch 100 | Validation loss = 1.2640 | Validation RMSE = 19.5906
Epoch 150 | Validation loss = 1.2476 | Validation RMSE = 19.1869
Epoch 200 | Validation loss = 1.2440 | Validation RMSE = 19.0512
Epoch 250 | Validation loss = 1.2401 | Validation RMSE = 18.9143
Epoch 300 | Validation loss = 1.2316 | Validation RMSE = 18.9218
Epoch 350 | Validation loss = 1.2327 | Validation RMSE = 18.8604
Epoch 400 | Validation loss = 1.2312 | Validation RMSE = 18.8614
CPU times: user 2min 15s, sys: 25.4 s, total: 2min 41s
Wall time: 2min 30s
y_dist, _ = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)
val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
Validation RMSE = 18.89
val_r2 = r2_score(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')
Validation $R^2$ = 0.54
def plot_results(y_true, y_pred):
_, ax = plt.subplots(1, 1, figsize=(7, 7))
palette = sns.color_palette()
min_value = min(np.amin(y_true), np.amin(y_pred))
max_value = max(np.amax(y_true), np.amax(y_pred))
y_mid = np.linspace(min_value, max_value)
ax.plot(y_mid, y_mid, '--', color=palette[1])
ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5);
return ax
ax = plot_results(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
);
ax.text(270, 120, f'$R^2 = {val_r2:.2f}$')
ax.text(270, 100, f'$RMSE = {val_rmse:.2f}$')
ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');
def make_predictions(model, x):
mean_dist, std_dist = model.encoder(x)
inv_tr = model.y_scaler.inverse_transform
y_hat = inv_tr(mean_dist.mean)
    # Recover the standard deviations on the original target scale:
    # spread of the mean's posterior (used below as the epistemic part) ...
    std_from_mean_dist = inv_tr(mean_dist.mean + mean_dist.stddev) - y_hat
    # ... and spread implied by the predicted noise std (used below as the aleatoric part)
    std_from_std_dist = inv_tr(mean_dist.mean + std_dist.mean + std_dist.stddev) - y_hat
return y_hat, std_from_mean_dist, std_from_std_dist
y_hat, std_from_mean_dist, std_from_std_dist = make_predictions(model, x_val)
std = std_from_mean_dist + std_from_std_dist
plt.hist(y_hat.detach().numpy().flatten(), bins=50);
plt.hist(std_from_mean_dist.detach().numpy().flatten(), bins=50);
plt.hist(std_from_std_dist.detach().numpy().flatten(), bins=50);
absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten()
stds = std.detach().numpy().flatten()
df = pd.DataFrame.from_dict({
'error': absolute_errors,
'uncertainty': stds
})
df.corr()
 | error | uncertainty
---|---|---
error | 1.000000 | 0.313831
uncertainty | 0.313831 | 1.000000
def plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned):
absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten()
stds = std.detach().numpy().flatten()
if binned:
ret = binned_statistic(absolute_errors, stds, bins=100)
a = ret.bin_edges[1:]
b = ret.statistic
alpha = 1.0
else:
a = absolute_errors
b = stds
alpha = 0.3
_, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.scatter(a, b, alpha=alpha)
max_val = max(ax.get_xlim()[-1], ax.get_ylim()[-1])
x = np.linspace(0, max_val)
ax.plot(x, x, '--', color='#ccc', alpha=0.8)
return ax
ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned=True)
ax.set_title('Absolute error vs uncertainty (binned & averaged)');
ax.set_xlabel('Absolute error');
ax.set_ylabel('Standard deviations');
ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned=False)
ax.set_title('Absolute error vs uncertainty');
ax.set_xlabel('Absolute error');
ax.set_ylabel('Standard deviations');
def plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned):
absolute_errors = (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten()
stds = (std / y_val).detach().numpy().flatten()
if binned:
ret = binned_statistic(absolute_errors, stds, bins=100)
a = ret.bin_edges[1:]
b = ret.statistic
alpha = 1.0
else:
a = absolute_errors
b = stds
alpha = 0.3
_, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.scatter(a, b, alpha=alpha)
max_val = max(ax.get_xlim()[-1], ax.get_ylim()[-1])
x = np.linspace(0, max_val)
ax.plot(x, x, '--', color='#ccc', alpha=0.8)
return ax
ax = plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned=True)
ax.set_title('Absolute error vs uncertainty (normalized)');
ax.set_xlabel('Absolute error / actual');
ax.set_ylabel('Standard deviations / actual');
ax = plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned=False)
ax.set_title('Absolute error vs uncertainty');
ax.set_xlabel('Absolute error / actual');
ax.set_ylabel('Standard deviations / actual');
To probe the model on out-of-distribution data, we'll build a made-up input in which each feature is set to twice its maximum value over the validation set.
random_input_np = np.zeros((1, x_val.shape[1]))
for i in range(x_val.shape[1]):
fn = torch.amax
random_input_np[0, i] = fn(x_val[:, i]) * 2
random_input = torch.from_numpy(random_input_np)
random_input.shape
torch.Size([1, 28])
_, std_epistemic, std_aleatoric = make_predictions(model, random_input)
uncertainty = float((std_epistemic + std_aleatoric).detach().numpy())
print(f'Uncertainty on made up input: {uncertainty:.2f}')
Uncertainty on made up input: 711.99
Uncertainty is high for this out-of-distribution input, which is what we want!
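For comparison, the same decomposition can be applied to the in-distribution validation inputs; their predicted uncertainties should sit on the scale of the histograms plotted earlier, far below the value for the made-up input:
_, std_epi_val, std_ale_val = make_predictions(model, x_val)
print(float((std_epi_val + std_ale_val).mean()))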
_, std_epistemic, std_aleatoric = make_predictions(model, random_input)
print(f'Aleatoric uncertainty: {float(std_aleatoric.detach().numpy()):.2f}')
print(f'Epistemic uncertainty: {float(std_epistemic.detach().numpy()):.2f}')
Aleatoric uncertainty: 711.99
Epistemic uncertainty: 0.00