Parametrization of a PyTorch distribution with a neural network, for the purpose of uncertainty quantification.
We'll consider the OLS Regression Challenge, which aims at predicting cancer mortality rates for US counties.
Notes:
import os
from os.path import join
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from scipy.stats import binned_statistic
import statsmodels.api as sm
# Ensure we run from the project root, not the notebook directory.
cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()

# Plot styling.
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()

# Reproducibility: seed NumPy and PyTorch; use float64 tensors throughout.
seed = 456
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)

# Load the OLS Regression Challenge dataset (cancer mortality per US county).
df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.372500 | 4.333096 |
2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.922190 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.000000 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.665830 | 0.492135 | 54.027460 | 6.796657 |
5 rows × 34 columns
df.describe()
avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3.047000e+03 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | ... | 2438.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 |
mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 1.026374e+05 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 3.290592e+05 | 6.409087 | 529.628366 | 45.304480 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.517710 | 6.572814 | 1.985816 |
min | 6.000000 | 3.000000 | 59.700000 | 201.300000 | 22640.000000 | 8.270000e+02 | 3.200000 | 0.000000 | 22.300000 | 22.400000 | ... | 15.700000 | 13.500000 | 11.200000 | 2.600000 | 10.199155 | 0.000000 | 0.000000 | 0.000000 | 22.992490 | 0.000000 |
25% | 76.000000 | 28.000000 | 161.200000 | 420.300000 | 38882.500000 | 1.168400e+04 | 12.150000 | 0.000000 | 37.700000 | 36.350000 | ... | 41.000000 | 34.500000 | 30.900000 | 14.850000 | 77.296180 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
50% | 171.000000 | 61.000000 | 178.100000 | 453.549422 | 45207.000000 | 2.664300e+04 | 15.900000 | 0.000000 | 41.000000 | 39.600000 | ... | 48.700000 | 41.100000 | 36.300000 | 18.800000 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
75% | 518.000000 | 149.000000 | 195.200000 | 480.850000 | 52492.000000 | 6.867100e+04 | 20.400000 | 83.650776 | 44.000000 | 42.500000 | ... | 55.600000 | 47.700000 | 41.550000 | 23.100000 | 95.451693 | 10.509732 | 1.221037 | 2.177960 | 55.395132 | 6.493677 |
max | 38150.000000 | 14010.000000 | 362.800000 | 1206.900000 | 125635.000000 | 1.017029e+07 | 47.400000 | 9762.308998 | 624.000000 | 64.700000 | ... | 78.900000 | 70.700000 | 65.100000 | 46.600000 | 100.000000 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |
8 rows × 32 columns
# Distribution of the regression target.
_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');
# Select feature columns: drop the target, identifiers, and columns with nulls.
target = 'TARGET_deathRate'
features = [
    col for col in df.columns
    if col not in [
        target,
        'Geography',  # Label describing the county - each row has a different one
        'binnedInc',  # Redundant with median income?
        'PctSomeCol18_24',  # contains null values - ignoring for now
        'PctEmployed16_Over',  # contains null values - ignoring for now
        'PctPrivateCoverageAlone',  # contains null values - ignoring for now
    ]
]
print(len(features), 'features')
28 features
# Feature matrix and target column (kept 2-D so the scalers broadcast cleanly).
x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)
(3047, 28) (3047, 1)
class DeepNormalModel(torch.nn.Module):
    """Neural network that parametrizes a Normal predictive distribution.

    A shared hidden layer feeds two separate heads: one producing the mean
    and one producing the standard deviation of the distribution. Inputs
    are standardized with ``x_scaler``; the output distribution lives in
    the scale produced by ``y_scaler``.
    """

    def __init__(
        self,
        n_inputs,
        n_hidden,
        x_scaler,
        y_scaler,
    ):
        super().__init__()

        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        # Small constant keeping the stddev strictly positive.
        self.jitter = 1e-6

        # Trunk shared by both heads.
        self.shared = torch.nn.Linear(n_inputs, n_hidden)

        # Mean head.
        self.mean_hidden = torch.nn.Linear(n_hidden, n_hidden)
        self.mean_linear = torch.nn.Linear(n_hidden, 1)

        # Standard-deviation head.
        self.std_hidden = torch.nn.Linear(n_hidden, n_hidden)
        self.std_linear = torch.nn.Linear(n_hidden, 1)

        # Default dropout probability (p=0.5), reused after every hidden layer.
        self.dropout = torch.nn.Dropout()

    def forward(self, x):
        """Return a ``torch.distributions.Normal`` over the scaled target."""
        # Standardize the raw inputs, then run the shared trunk.
        trunk = self.dropout(F.relu(self.shared(self.x_scaler(x))))

        # Parametrization of the mean.
        hidden_mu = self.dropout(F.relu(self.mean_hidden(trunk)))
        mu = self.mean_linear(hidden_mu)

        # Parametrization of the standard deviation; softplus + jitter
        # guarantees a strictly positive scale.
        hidden_sigma = self.dropout(F.relu(self.std_hidden(trunk)))
        sigma = F.softplus(self.std_linear(hidden_sigma)) + self.jitter

        return torch.distributions.Normal(mu, sigma)
def compute_loss(model, x, y, kl_reg=0.1):
    """Mean negative log-likelihood of ``y`` under the predicted distribution.

    Targets are brought into the model's standardized scale before the
    likelihood is evaluated. ``kl_reg`` is currently unused and kept only
    for interface compatibility.
    """
    y_scaled = model.y_scaler(y)
    predictive_dist = model(x)
    return -predictive_dist.log_prob(y_scaled).mean()
def compute_rmse(model, x_test, y_test):
    """Root-mean-squared error of the mean prediction, in original target units."""
    model.eval()  # disable dropout so predictions are deterministic
    predictive_dist = model(x_test)
    predictions = model.y_scaler.inverse_transform(predictive_dist.mean)
    squared_errors = (predictions - y_test) ** 2
    return torch.sqrt(squared_errors.mean())
def train_one_step(model, optimizer, x_batch, y_batch):
    """Take one gradient-descent step on a mini-batch and return its loss."""
    model.train()  # re-enable dropout during optimization
    optimizer.zero_grad()

    step_loss = compute_loss(model, x_batch, y_batch)
    step_loss.backward()
    optimizer.step()

    return step_loss
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
    """Mini-batch training loop with per-epoch validation tracking.

    After every optimization step the validation loss and RMSE are
    evaluated on the full validation set; per-epoch values are averages
    over that epoch's batches. Plots a loss overview and returns
    ``(train_losses, val_losses)``.
    """
    train_losses, val_losses = [], []

    for epoch in range(n_epochs):
        epoch_train, epoch_val, epoch_rmse = [], [], []

        # Fresh random batch partition every epoch.
        for batch_ix in sample_batch_indices(x_train, y_train, batch_size):
            step_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])

            # Evaluate on the validation set after each step.
            model.eval()
            step_val_loss = compute_loss(model, x_val, y_val)
            step_val_rmse = compute_rmse(model, x_val, y_val)

            epoch_train.append(step_loss.detach().numpy())
            epoch_val.append(step_val_loss.detach().numpy())
            epoch_rmse.append(step_val_rmse.detach().numpy())

        if scheduler is not None:
            scheduler.step()

        train_loss = np.mean(epoch_train)
        val_loss = np.mean(epoch_val)
        val_rmse = np.mean(epoch_rmse)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')

    # Training overview: train vs validation loss per epoch.
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
    ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()

    return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
    """Shuffle the row indices of ``x`` and split them into batch-sized chunks.

    The final batch may be smaller than ``batch_size``. Pass a
    ``np.random.RandomState`` via ``rs`` for reproducible shuffling.
    """
    if rs is None:
        rs = np.random.RandomState()

    shuffled = np.arange(len(x))
    rs.shuffle(shuffled)

    # Consecutive slices of the shuffled index array form the batches.
    return [
        shuffled[start:start + batch_size].tolist()
        for start in range(0, len(x), batch_size)
    ]
class StandardScaler(object):
    """
    Standardize data by removing the mean and scaling to unit variance.

    Operates on torch tensors; statistics are computed per column (dim 0).
    Zero-variance columns are guarded against so that `transform` never
    divides by zero: such columns map to 0 and invert back to their mean.
    """

    def __init__(self):
        # Populated by fit(): per-column mean and standard deviation,
        # each with shape (1, n_columns) so they broadcast over rows.
        self.mean = None
        self.scale = None

    def fit(self, sample):
        """Compute per-column mean and biased standard deviation; return self."""
        self.mean = sample.mean(0, keepdim=True)
        # Guard against division by zero for constant columns.
        self.scale = sample.std(0, unbiased=False, keepdim=True).clamp(min=1e-12)
        return self

    def __call__(self, sample):
        return self.transform(sample)

    def transform(self, sample):
        """Return the standardized version of ``sample``."""
        return (sample - self.mean) / self.scale

    def inverse_transform(self, sample):
        """Map standardized values back to the original scale."""
        return sample * self.scale + self.mean
def compute_train_test_split(x, y, test_size):
    """Random train/test split returning torch tensors plus row indices.

    Returns ``(x_train, x_test, y_train, y_test, train_ix, test_ix)``;
    the index lists refer to row positions in the original arrays.
    """
    split = train_test_split(
        x,
        y,
        list(range(len(x))),  # carry row indices through the split
        test_size=test_size,
    )
    x_train, x_test, y_train, y_test, train_ix, test_ix = split

    tensors = tuple(torch.from_numpy(part) for part in (x_train, x_test, y_train, y_test))
    return tensors + (train_ix, test_ix)
# 80/20 train/validation split; the index lists map rows back to df.
x_train, x_val, y_train, y_val, train_ix, test_ix = compute_train_test_split(x, y, test_size=0.2)
print(x_train.shape, y_train.shape, len(train_ix))
print(x_val.shape, y_val.shape, len(test_ix))
torch.Size([2437, 28]) torch.Size([2437, 1]) 2437 torch.Size([610, 28]) torch.Size([610, 1]) 610
# Fit the scalers on the full dataset.
# NOTE(review): this leaks validation statistics into training — consider
# fitting on x_train / y_train only.
x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))
%%time

# Training hyperparameters.
learning_rate = 1e-3
momentum = 0.9
weight_decay = 1e-5
n_epochs = 300
batch_size = 64
print_every = 50
n_hidden = 100

model = DeepNormalModel(
    n_inputs=x.shape[1],
    n_hidden=n_hidden,
    x_scaler=x_scaler,
    y_scaler=y_scaler,
)

# Count trainable parameters for reference.
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{pytorch_total_params:,} trainable parameters')
print()

# SGD with Nesterov momentum and L2 regularization (weight decay).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    nesterov=True,
    weight_decay=weight_decay,
)
scheduler = None

train_losses, val_losses = train(
    model,
    optimizer,
    x_train,
    x_val,
    y_train,
    y_val,
    n_epochs=n_epochs,
    batch_size=batch_size,
    scheduler=scheduler,
    print_every=print_every,
)
23,302 trainable parameters Epoch 1 | Validation loss = 1.3449 | Validation RMSE = 25.9272 Epoch 50 | Validation loss = 0.9299 | Validation RMSE = 18.2366 Epoch 100 | Validation loss = 0.9070 | Validation RMSE = 17.9943 Epoch 150 | Validation loss = 0.8992 | Validation RMSE = 17.8391 Epoch 200 | Validation loss = 0.8917 | Validation RMSE = 17.6591 Epoch 250 | Validation loss = 0.8986 | Validation RMSE = 17.6407 Epoch 300 | Validation loss = 0.8858 | Validation RMSE = 17.4868 CPU times: user 54.6 s, sys: 6.19 s, total: 1min Wall time: 57.1 s
# Mean predictions on the validation set, back in the original target scale.
y_dist = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)

val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
Validation RMSE = 17.54
# Coefficient of determination on the validation set.
val_r2 = r2_score(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')
Validation $R^2$ = 0.55
def plot_results(y_true, y_pred):
    """Scatter predictions against actuals with an identity reference line."""
    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    palette = sns.color_palette()

    # Shared axis range covering both series.
    lo = min(np.amin(y_true), np.amin(y_pred))
    hi = max(np.amax(y_true), np.amax(y_pred))

    # Identity line: perfect predictions would fall on it.
    diagonal = np.linspace(lo, hi)
    ax.plot(diagonal, diagonal, '--', color=palette[1])

    ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5)
    return f, ax
# Predictions vs actuals, annotated with the validation R² and RMSE.
f, ax = plot_results(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
);
ax.text(225, 95, f'$R^2 = {val_r2:.2f}$')
ax.text(225, 80, f'$RMSE = {val_rmse:.2f}$')
ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');
def make_predictions(model, x):
    """Return (mean, stddev) predictions in the original target scale.

    The model predicts a Normal distribution in standardized target space.
    The mean is inverted directly; the stddev is recovered by inverting
    the (mean + stddev) point and subtracting the inverted mean, which is
    exact because the scaler's inverse is affine.
    """
    dist = model(x)
    inverse = model.y_scaler.inverse_transform

    mean_pred = inverse(dist.mean)
    std_pred = inverse(dist.mean + dist.stddev) - mean_pred
    return mean_pred, std_pred
# Distributions of the predicted means and stddevs over the validation set.
y_hat, std = make_predictions(model, x_val)

f, axes = plt.subplots(1, 2, figsize=(18, 6))
ax1, ax2 = axes

ax1.hist(y_hat.detach().numpy().flatten(), bins=50);
ax2.hist(std.detach().numpy().flatten(), bins=50, color=palette[1]);

ax1.set_xlabel('Cancer mortality rate (mean)');
ax1.set_ylabel('Instance count');
# Raw strings: '\m' and '\s' are invalid escape sequences in plain literals.
ax1.set_title(r'Distribution of mean $\mu$ on validation set');
ax2.set_xlabel('Cancer mortality rate (stddev)');
ax2.set_ylabel('Instance count');
ax2.set_title(r'Distribution of stddev $\sigma$ on validation set');
# Pick one low-, one average-, and one high-mortality prediction to plot.
low_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if float(v) < 140][0]
average_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if 170 < float(v) < 190][0]
high_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if float(v) > 220][0]
def plot_normal_distributions(y_hat, std, indices):
    """Plot the predictive Normal PDFs of the selected validation instances."""
    f, ax = plt.subplots(1, 1, figsize=(12, 6))

    # Gaussian density evaluated manually on a fixed grid.
    normal_fn = lambda z, m, s: (1/(s * np.sqrt(2*np.pi))) * np.exp(-0.5*((z - m) / s)**2)
    grid = np.linspace(80, 300, 300)

    means = y_hat.detach().numpy().flatten()
    scales = std.detach().numpy().flatten()
    for ix in indices:
        mu, sigma = means[ix], scales[ix]
        ax.plot(grid, normal_fn(grid, mu, sigma), label=rf'$\mu = {mu:.0f}$, $\sigma = {sigma:.1f}$')

    ax.set_xlabel('Cancer mortality rate ($y$)')
    ax.set_ylabel(r'$P(y \mid \mu, \sigma)$')
    ax.set_title('PDF of three instances of the validation set')
    ax.legend(loc='upper left')
    return f, ax
# PDFs of the three selected instances.
f, ax = plot_normal_distributions(y_hat, std, [low_ix, average_ix, high_ix])
def plot_absolute_error_vs_uncertainty(y_val, y_hat, std):
    """Scatter relative absolute error (%) against the predicted stddev.

    Fits a simple linear regression of uncertainty on error and prints
    its R² plus the Pearson correlation between the two quantities.
    """
    rel_errors = 100 * (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten()
    uncertainties = std.detach().numpy().flatten()

    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    ax.scatter(rel_errors, uncertainties, alpha=0.5)

    # Least-squares fit of uncertainty on error.
    design = rel_errors.reshape((-1, 1))
    linreg = LinearRegression().fit(design, uncertainties)
    r2 = linreg.score(design, uncertainties)
    corr = np.corrcoef(rel_errors, uncertainties)[0, 1]

    print(f'R2: {r2:.2f}')
    print(f'Correlation: {corr:.2f}')

    fitted = linreg.coef_[0] * rel_errors + linreg.intercept_
    ax.plot(rel_errors, fitted, '-', color=palette[1], alpha=0.5)
    ax.text(56, 7.5, f'Corr = {corr:.2f}')
    return f, ax
# How well does the predicted uncertainty track the actual error?
f, ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std)
ax.set_title('Error vs uncertainty');
ax.set_xlabel('Absolute error / actual (%)');
ax.set_ylabel('Standard deviation');
R2: 0.08 Correlation: 0.28
We'll consider an out-of-distribution input made of twice the maximum of each feature.
# Build a synthetic out-of-distribution input: twice the maximum of each feature.
max_input_np = np.zeros((1, x_val.shape[1]))
for i in range(x_val.shape[1]):
    max_input_np[0, i] = torch.amax(x_val[:, i]) * 2
max_input = torch.from_numpy(max_input_np)
max_input.shape
torch.Size([1, 28])
# Relative (first-order) uncertainty for the extreme input.
r_y_hat, r_first_order_std = make_predictions(model, max_input)
print(f'First order uncertainty: {float((r_first_order_std / r_y_hat).detach().numpy()):.2f}')
First order uncertainty: 2.59
An instance's unexpectedness can be measured by the absolute distance between its average feature quantile and the median quantile of 0.5.
We can then estimate the relation between uncertainty and unexpectedness.
def set_feature_quantile(df, features):
    """Add a 'quantile' column: each row's average feature quantile.

    For every feature, a row's value is mapped to the smallest percentile
    q in {0.00, 0.01, ..., 0.99} whose quantile value is >= the row's
    value; values above the 99th percentile map to 1.0. A row's score is
    the mean of these per-feature quantiles.
    """
    quantiles = np.arange(0., 1., 0.01)
    # Precompute the quantile table once: one Series of feature quantiles per q.
    series = [df[features].quantile(q) for q in quantiles]

    def apply_quantile(row):
        v = np.zeros((len(features)))
        for i, f in enumerate(features):
            value = row[f]
            for j, q in enumerate(quantiles):
                if value <= series[j][f]:
                    v[i] = q
                    break
            else:
                # Bug fix: values above the 99th percentile previously fell
                # through the loop and kept the default 0.0 — the opposite
                # of their true rank. Map them to 1.0 instead.
                v[i] = 1.0
        return v.mean()

    df['quantile'] = df.apply(apply_quantile, axis=1)
# Compute and inspect the per-row average feature quantile.
set_feature_quantile(df, features)
df['quantile'].hist(bins=20);
def plot_expectedness_vs_uncertainty(df, test_ix, y_val, std):
    """Scatter an instance's unexpectedness against its predicted stddev.

    Unexpectedness is the absolute distance (in %) between a row's average
    feature quantile and the median 0.5. An OLS fit and the Pearson
    correlation quantify the relationship.
    """
    test_quantiles = df['quantile'].iloc[test_ix].values
    unexpectedness = 100 * np.abs(test_quantiles - 0.5)
    uncertainties = std.detach().numpy().flatten()

    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    ax.scatter(unexpectedness, uncertainties, alpha=0.5)

    # OLS with an intercept term via statsmodels.
    design = sm.add_constant(unexpectedness.reshape((-1, 1)))
    response = uncertainties.reshape((-1, 1))
    linreg = sm.OLS(response, design).fit()
    fitted = linreg.predict(design).flatten()

    corr = np.corrcoef(unexpectedness, uncertainties)[0, 1]
    print(f'Correlation: {corr:.2f}')
    print(linreg.summary())

    ax.plot(unexpectedness, fitted, '-', color=palette[1], alpha=0.5)
    ax.text(27, 7.5, f'Corr = {corr:.2f}')
    return f, ax
# Does uncertainty grow for instances far from typical feature values?
f, ax = plot_expectedness_vs_uncertainty(df, test_ix, y_val, std)
ax.set_xlabel('Unexpectedness (%)')
ax.set_ylabel('Standard deviation')
ax.set_title('Unexpectedness vs uncertainty');
Correlation: 0.33 OLS Regression Results ============================================================================== Dep. Variable: y R-squared: 0.108 Model: OLS Adj. R-squared: 0.107 Method: Least Squares F-statistic: 73.98 Date: Fri, 07 Jan 2022 Prob (F-statistic): 6.72e-17 Time: 11:36:17 Log-Likelihood: -1754.4 No. Observations: 610 AIC: 3513. Df Residuals: 608 BIC: 3522. Df Model: 1 Covariance Type: nonrobust ============================================================================== coef std err t P>|t| [0.025 0.975] ------------------------------------------------------------------------------ const 12.7535 0.281 45.323 0.000 12.201 13.306 x1 0.3532 0.041 8.601 0.000 0.273 0.434 ============================================================================== Omnibus: 145.715 Durbin-Watson: 2.140 Prob(Omnibus): 0.000 Jarque-Bera (JB): 468.950 Skew: 1.117 Prob(JB): 1.48e-102 Kurtosis: 6.669 Cond. No. 11.2 ============================================================================== Notes: [1] Standard Errors assume that the covariance matrix of the errors is correctly specified.