Implementation of Stochastic Variational Inference (SVI) using PyTorch, for the purpose of uncertainty quantification.
We'll consider the OLS Regression Challenge, which aims at predicting cancer mortality rates for US counties.
Notes:
import os
from os.path import join
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from scipy.stats import binned_statistic
cwd = os.getcwd()
if cwd.endswith('notebook'):
os.chdir('..')
cwd = os.getcwd()
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
seed = 444
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)
df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.372500 | 4.333096 |
2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.922190 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.000000 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.665830 | 0.492135 | 54.027460 | 6.796657 |
5 rows × 34 columns
df.describe()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3.047000e+03 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | ... | 2438.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 |
mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 1.026374e+05 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 3.290592e+05 | 6.409087 | 529.628366 | 45.304480 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.517710 | 6.572814 | 1.985816 |
min | 6.000000 | 3.000000 | 59.700000 | 201.300000 | 22640.000000 | 8.270000e+02 | 3.200000 | 0.000000 | 22.300000 | 22.400000 | ... | 15.700000 | 13.500000 | 11.200000 | 2.600000 | 10.199155 | 0.000000 | 0.000000 | 0.000000 | 22.992490 | 0.000000 |
25% | 76.000000 | 28.000000 | 161.200000 | 420.300000 | 38882.500000 | 1.168400e+04 | 12.150000 | 0.000000 | 37.700000 | 36.350000 | ... | 41.000000 | 34.500000 | 30.900000 | 14.850000 | 77.296180 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
50% | 171.000000 | 61.000000 | 178.100000 | 453.549422 | 45207.000000 | 2.664300e+04 | 15.900000 | 0.000000 | 41.000000 | 39.600000 | ... | 48.700000 | 41.100000 | 36.300000 | 18.800000 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
75% | 518.000000 | 149.000000 | 195.200000 | 480.850000 | 52492.000000 | 6.867100e+04 | 20.400000 | 83.650776 | 44.000000 | 42.500000 | ... | 55.600000 | 47.700000 | 41.550000 | 23.100000 | 95.451693 | 10.509732 | 1.221037 | 2.177960 | 55.395132 | 6.493677 |
max | 38150.000000 | 14010.000000 | 362.800000 | 1206.900000 | 125635.000000 | 1.017029e+07 | 47.400000 | 9762.308998 | 624.000000 | 64.700000 | ... | 78.900000 | 70.700000 | 65.100000 | 46.600000 | 100.000000 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |
8 rows × 32 columns
_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');
target = 'TARGET_deathRate'
features = [
col for col in df.columns
if col not in [
target,
'Geography', # Label describing the county - each row has a different one
'binnedInc', # Redundant with median income?
'PctSomeCol18_24', # contains null values - ignoring for now
'PctEmployed16_Over', # contains null values - ignoring for now
'PctPrivateCoverageAlone', # contains null values - ignoring for now
]
]
print(len(features), 'features')
28 features
x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)
(3047, 28) (3047, 1)
class VariationalModel(nn.Module):
def __init__(
self,
n_inputs,
encoding_size,
n_hidden,
x_scaler,
y_scaler,
n_mixtures=5,
n_enc_layers=1,
n_dec_layers=1,
jitter=1e-8,
):
super().__init__()
self.n_inputs = n_inputs
self.encoding_size = encoding_size
self.n_mixtures = n_mixtures
self.jitter = jitter
self.x_scaler = x_scaler
self.y_scaler = y_scaler
# Prior
self.prior = torch.distributions.Independent(
torch.distributions.Normal(
torch.zeros(n_mixtures, encoding_size),
torch.ones(n_mixtures, encoding_size),
),
1,
)
# Encoder
enc_shared_layers = []
for i in range(n_enc_layers):
if i == 0:
layer = nn.Linear(n_inputs, n_hidden)
else:
layer = nn.Linear(n_hidden, n_hidden)
enc_shared_layers.append(layer)
enc_shared_layers.append(nn.ReLU())
self.enc_shared = nn.Sequential(*enc_shared_layers)
self.mean_gmm_layer = nn.Linear(n_hidden, encoding_size * n_mixtures)
self.std_gmm_layer = nn.Linear(n_hidden, encoding_size * n_mixtures)
# Decoder
dec_shared_layers = []
for i in range(n_dec_layers):
if i == 0:
layer = nn.Linear(encoding_size, n_hidden)
else:
layer = nn.Linear(n_hidden, n_hidden)
dec_shared_layers.append(layer)
dec_shared_layers.append(nn.ReLU())
self.dec_shared = nn.Sequential(*dec_shared_layers)
self.mean_layer = nn.Linear(n_hidden, 1)
self.std_layer = nn.Linear(n_hidden, 1)
def encode(self, x):
shared = self.x_scaler(x)
shared = self.enc_shared(shared)
shp = (-1, self.n_mixtures, self.encoding_size)
mean = torch.reshape(self.mean_gmm_layer(shared), shp)
std = torch.reshape(F.softplus(self.std_gmm_layer(shared)) + self.jitter, shp)
return torch.distributions.Independent(torch.distributions.Normal(mean, std), 1)
def decode(self, x_enc):
shared = self.dec_shared(x_enc)
mean = self.mean_layer(shared)
std = F.softplus(self.std_layer(shared)) + self.jitter
return torch.distributions.Normal(mean, std)
def forward(self, x):
gmm_dist = self.encode(x)
encoding = torch.mean(gmm_dist.mean, dim=1)
return self.decode(encoding)
def compute_loss(model, x, y):
"""
ELBO loss - https://en.wikipedia.org/wiki/Evidence_lower_bound
"""
y_scaled = model.y_scaler(y)
gmm_dist = model.encode(x)
encoding = torch.mean(gmm_dist.rsample(), dim=1)
y_hat = model.decode(encoding)
neg_log_likelihood = -y_hat.log_prob(y_scaled)
kl_divergence = torch.distributions.kl.kl_divergence(gmm_dist, model.prior)
return torch.mean(neg_log_likelihood + kl_divergence)
def compute_rmse(model, x_test, y_test):
model.eval()
y_hat = model(x_test)
pred = model.y_scaler.inverse_transform(y_hat.mean)
return torch.sqrt(torch.mean((pred - y_test)**2))
def train_one_step(model, optimizer, x_batch, y_batch):
model.train()
optimizer.zero_grad()
loss = compute_loss(model, x_batch, y_batch)
loss.backward()
optimizer.step()
return loss
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
train_losses, val_losses = [], []
for epoch in range(n_epochs):
batch_indices = sample_batch_indices(x_train, y_train, batch_size)
batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
for batch_ix in batch_indices:
b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])
model.eval()
b_val_loss = compute_loss(model, x_val, y_val)
b_val_rmse = compute_rmse(model, x_val, y_val)
batch_losses_t.append(b_train_loss.detach().numpy())
batch_losses_v.append(b_val_loss.detach().numpy())
batch_rmse_v.append(b_val_rmse.detach().numpy())
if scheduler is not None:
scheduler.step()
train_loss = np.mean(batch_losses_t)
val_loss = np.mean(batch_losses_v)
val_rmse = np.mean(batch_rmse_v)
train_losses.append(train_loss)
val_losses.append(val_loss)
if epoch == 0 or (epoch + 1) % print_every == 0:
print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Overview')
ax.legend()
return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
if rs is None:
rs = np.random.RandomState()
train_ix = np.arange(len(x))
rs.shuffle(train_ix)
n_batches = int(np.ceil(len(x) / batch_size))
batch_indices = []
for i in range(n_batches):
start = i * batch_size
end = start + batch_size
batch_indices.append(
train_ix[start:end].tolist()
)
return batch_indices
class StandardScaler(object):
"""
Standardize data by removing the mean and scaling to unit variance.
"""
def __init__(self):
self.mean = None
self.scale = None
def fit(self, sample):
self.mean = sample.mean(0, keepdim=True)
self.scale = sample.std(0, unbiased=False, keepdim=True)
return self
def __call__(self, sample):
return self.transform(sample)
def transform(self, sample):
return (sample - self.mean) / self.scale
def inverse_transform(self, sample):
return sample * self.scale + self.mean
def compute_train_test_split(x, y, test_size):
x_train, x_test, y_train, y_test, train_ix, test_ix = train_test_split(
x,
y,
list(range(len(x))),
test_size=test_size,
)
return (
torch.from_numpy(x_train),
torch.from_numpy(x_test),
torch.from_numpy(y_train),
torch.from_numpy(y_test),
train_ix,
test_ix,
)
x_train, x_val, y_train, y_val, train_ix, test_ix = compute_train_test_split(x, y, test_size=0.2)
print(x_train.shape, y_train.shape, len(train_ix))
print(x_val.shape, y_val.shape, len(test_ix))
torch.Size([2437, 28]) torch.Size([2437, 1]) 2437 torch.Size([610, 28]) torch.Size([610, 1]) 610
x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))
model = VariationalModel(
n_inputs=x.shape[1],
encoding_size=10,
n_hidden=100,
x_scaler=x_scaler,
y_scaler=y_scaler,
n_mixtures=5,
n_enc_layers=1,
n_dec_layers=1,
)
model
VariationalModel( (enc_shared): Sequential( (0): Linear(in_features=28, out_features=100, bias=True) (1): ReLU() ) (mean_gmm_layer): Linear(in_features=100, out_features=50, bias=True) (std_gmm_layer): Linear(in_features=100, out_features=50, bias=True) (dec_shared): Sequential( (0): Linear(in_features=10, out_features=100, bias=True) (1): ReLU() ) (mean_layer): Linear(in_features=100, out_features=1, bias=True) (std_layer): Linear(in_features=100, out_features=1, bias=True) )
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{pytorch_total_params:,} trainable parameters')
print()
14,302 trainable parameters
learning_rate = 1e-4
optimizer = torch.optim.Adam(
model.parameters(),
lr=learning_rate,
)
n_epochs = 1200
batch_size = 64
print_every = 50
train_losses, val_losses = train(
model,
optimizer,
x_train,
x_val,
y_train,
y_val,
n_epochs=n_epochs,
batch_size=batch_size,
print_every=print_every,
)
Epoch 1 | Validation loss = 3.1302 | Validation RMSE = 28.0765 Epoch 50 | Validation loss = 1.4088 | Validation RMSE = 22.9438 Epoch 100 | Validation loss = 1.3531 | Validation RMSE = 20.8623 Epoch 150 | Validation loss = 1.3479 | Validation RMSE = 20.3298 Epoch 200 | Validation loss = 1.3469 | Validation RMSE = 20.1943 Epoch 250 | Validation loss = 1.3388 | Validation RMSE = 20.0699 Epoch 300 | Validation loss = 1.3446 | Validation RMSE = 19.6804 Epoch 350 | Validation loss = 1.3457 | Validation RMSE = 19.6464 Epoch 400 | Validation loss = 1.3456 | Validation RMSE = 19.5523 Epoch 450 | Validation loss = 1.3483 | Validation RMSE = 19.6363 Epoch 500 | Validation loss = 1.3462 | Validation RMSE = 19.3257 Epoch 550 | Validation loss = 1.3380 | Validation RMSE = 19.2351 Epoch 600 | Validation loss = 1.3586 | Validation RMSE = 19.1674 Epoch 650 | Validation loss = 1.3522 | Validation RMSE = 19.0560 Epoch 700 | Validation loss = 1.3601 | Validation RMSE = 19.0464 Epoch 750 | Validation loss = 1.3578 | Validation RMSE = 18.7227 Epoch 800 | Validation loss = 1.3654 | Validation RMSE = 18.6508 Epoch 850 | Validation loss = 1.3640 | Validation RMSE = 18.6167 Epoch 900 | Validation loss = 1.3600 | Validation RMSE = 18.5071 Epoch 950 | Validation loss = 1.3726 | Validation RMSE = 18.4926 Epoch 1000 | Validation loss = 1.3745 | Validation RMSE = 18.4887 Epoch 1050 | Validation loss = 1.3723 | Validation RMSE = 18.2981 Epoch 1100 | Validation loss = 1.3841 | Validation RMSE = 18.2612 Epoch 1150 | Validation loss = 1.3823 | Validation RMSE = 18.5471 Epoch 1200 | Validation loss = 1.3951 | Validation RMSE = 18.5240
y_dist = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)
val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
Validation RMSE = 18.29
val_r2 = r2_score(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')
Validation $R^2$ = 0.57
def plot_results(y_true, y_pred):
_, ax = plt.subplots(1, 1, figsize=(7, 7))
palette = sns.color_palette()
min_value = min(np.amin(y_true), np.amin(y_pred))
max_value = max(np.amax(y_true), np.amax(y_pred))
y_mid = np.linspace(min_value, max_value)
ax.plot(y_mid, y_mid, '--', color=palette[1])
ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5);
return ax
ax = plot_results(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
);
ax.text(270, 120, f'$R^2 = {val_r2:.2f}$')
ax.text(270, 100, f'$RMSE = {val_rmse:.2f}$')
ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');
def compute_first_order_uncertainty(model, x):
gmm_dist = model.encode(x)
encoding = torch.mean(gmm_dist.mean, dim=1)
out_dist = model.decode(encoding)
inv_tr = model.y_scaler.inverse_transform
y_hat = inv_tr(out_dist.mean)
std = inv_tr(out_dist.mean + out_dist.stddev) - y_hat
return std.detach().numpy()
first_order_std = compute_first_order_uncertainty(model, x_val)
first_order_std.shape
(610, 1)
def compute_second_order_uncertainty(model, x, n_samples=1000, std=False):
gmm_dist = model.encode(x)
encoding_sample = torch.mean(gmm_dist.rsample((n_samples,)), dim=2)
inv_tr = model.y_scaler.inverse_transform
shp = encoding_sample.shape
res = np.zeros((n_samples, shp[1], 1))
stds = np.zeros((n_samples, shp[1], 1))
for i in range(n_samples):
d = model.decode(encoding_sample[i])
mean = d.mean
res[i] = inv_tr(mean).detach().numpy()
stds[i] = (inv_tr(mean + d.stddev) - inv_tr(mean)).detach().numpy()
second_order_std = np.std(res, axis=0)
std_std = np.mean(stds, axis=0)
if not std:
return second_order_std
else:
return second_order_std, std_std
second_order_std, std_std = compute_second_order_uncertainty(model, x_val, std=True)
second_order_std.shape
(610, 1)
plt.hist(first_order_std.flatten(), bins=50);
plt.hist(second_order_std.flatten(), bins=50);
plt.hist(std_std.flatten(), bins=50);
We'll consider an input made of the maximum of all features.
max_input_np = np.zeros((1, x_val.shape[1]))
for i in range(x_val.shape[1]):
max_input_np[0, i] = torch.amax(x_val[:, i]) * 2
random_input = torch.from_numpy(max_input_np)
random_input.shape
torch.Size([1, 28])
r_first_order_std = compute_first_order_uncertainty(model, random_input)
r_second_order_std = compute_second_order_uncertainty(model, random_input)
print(f'First order uncertainty: {float(r_first_order_std):.2f}')
print(f'Second order uncertainty: {float(r_second_order_std):.2f}')
First order uncertainty: 28.76 Second order uncertainty: 9.42
An instance's expectedness can be measured by its average feature quantile.
We can then estimate the relation between uncertainty and expectedness.
def set_feature_quantile(df, features):
series = []
quantiles = np.arange(0., 1., 0.01)
for q in quantiles:
series.append(df[features].quantile(q))
def apply_quantile(row):
v = np.zeros((len(features)))
for i, f in enumerate(features):
value = row[f]
for j, q in enumerate(quantiles):
if value <= series[j][f]:
v[i] = q
break
return v.mean()
df['quantile'] = df.apply(apply_quantile, axis=1)
set_feature_quantile(df, features)
df['quantile'].head()
0 0.589286 1 0.413571 2 0.550000 3 0.568929 4 0.481786 Name: quantile, dtype: float64
df['quantile'].hist(bins=20);
def plot_expectedness_vs_uncertainty(df, test_ix, y_val, std):
expectedness = np.abs(df['quantile'].iloc[test_ix].values - 0.5)
stds = std.flatten() / y_val.detach().numpy().flatten()
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.scatter(expectedness, stds, alpha=0.3)
X = expectedness.reshape((-1, 1))
y = stds
linreg = LinearRegression().fit(X, y)
print(f'R2: {linreg.score(X, y):.2f}')
print(f'Correlation: {np.corrcoef(expectedness, stds)[0, 1]:.2f}')
y_fit = linreg.coef_[0] * expectedness + linreg.intercept_
ax.plot(expectedness, y_fit, '-', color=palette[1], alpha=0.5)
return ax
ax = plot_expectedness_vs_uncertainty(df, test_ix, y_val, first_order_std)
R2: 0.01 Correlation: 0.09
ax = plot_expectedness_vs_uncertainty(df, test_ix, y_val, second_order_std)
R2: 0.01 Correlation: 0.12
def plot_absolute_error_vs_uncertainty(y_val, y_hat, std):
absolute_errors = (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten()
stds = std.flatten() / y_val.detach().numpy().flatten()
a = absolute_errors
b = stds
alpha = 0.3
_, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.scatter(a, b, alpha=alpha)
X = a.reshape((-1, 1))
y = b
linreg = LinearRegression().fit(X, y)
print(f'R2: {linreg.score(X, y):.2f}')
print(f'Correlation: {np.corrcoef(a, b)[0, 1]:.2f}')
y_fit = linreg.coef_[0] * a + linreg.intercept_
ax.plot(a, y_fit, '-', color=palette[1], alpha=0.5)
return ax
plot_absolute_error_vs_uncertainty(y_val, y_hat, first_order_std);
R2: 0.23 Correlation: 0.48
plot_absolute_error_vs_uncertainty(y_val, y_hat, second_order_std);
R2: 0.24 Correlation: 0.49
plot_absolute_error_vs_uncertainty(y_val, y_hat, first_order_std + second_order_std);
R2: 0.24 Correlation: 0.49