Stochastic Variational Inference

Implementation of Stochastic Variational Inference (SVI) using PyTorch, for the purpose of uncertainty quantification.

We'll consider the OLS Regression Challenge, which aims at predicting cancer mortality rates for US counties.

Data Dictionary

  • TARGET_deathRate: Dependent variable. Mean per capita (100,000) cancer mortalities (a)
  • avgAnnCount: Mean number of reported cases of cancer diagnosed annually (a)
  • avgDeathsPerYear: Mean number of reported mortalities due to cancer (a)
  • incidenceRate: Mean per capita (100,000) cancer diagoses (a)
  • medianIncome: Median income per county (b)
  • popEst2015: Population of county (b)
  • povertyPercent: Percent of populace in poverty (b)
  • studyPerCap: Per capita number of cancer-related clinical trials per county (a)
  • binnedInc: Median income per capita binned by decile (b)
  • MedianAge: Median age of county residents (b)
  • MedianAgeMale: Median age of male county residents (b)
  • MedianAgeFemale: Median age of female county residents (b)
  • Geography: County name (b)
  • AvgHouseholdSize: Mean household size of county (b)
  • PercentMarried: Percent of county residents who are married (b)
  • PctNoHS18_24: Percent of county residents ages 18-24 highest education attained: less than high school (b)
  • PctHS18_24: Percent of county residents ages 18-24 highest education attained: high school diploma (b)
  • PctSomeCol18_24: Percent of county residents ages 18-24 highest education attained: some college (b)
  • PctBachDeg18_24: Percent of county residents ages 18-24 highest education attained: bachelor's degree (b)
  • PctHS25_Over: Percent of county residents ages 25 and over highest education attained: high school diploma (b)
  • PctBachDeg25_Over: Percent of county residents ages 25 and over highest education attained: bachelor's degree (b)
  • PctEmployed16_Over: Percent of county residents ages 16 and over employed (b)
  • PctUnemployed16_Over: Percent of county residents ages 16 and over unemployed (b)
  • PctPrivateCoverage: Percent of county residents with private health coverage (b)
  • PctPrivateCoverageAlone: Percent of county residents with private health coverage alone (no public assistance) (b)
  • PctEmpPrivCoverage: Percent of county residents with employee-provided private health coverage (b)
  • PctPublicCoverage: Percent of county residents with government-provided health coverage (b)
  • PctPubliceCoverageAlone: Percent of county residents with government-provided health coverage alone (b)
  • PctWhite: Percent of county residents who identify as White (b)
  • PctBlack: Percent of county residents who identify as Black (b)
  • PctAsian: Percent of county residents who identify as Asian (b)
  • PctOtherRace: Percent of county residents who identify in a category which is not White, Black, or Asian (b)
  • PctMarriedHouseholds: Percent of married households (b)
  • BirthRate: Number of live births relative to number of women in county (b)

Notes:

  • (a): years 2010-2016
  • (b): 2013 Census Estimates
In [1]:
import os
from os.path import join

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from scipy.stats import binned_statistic

cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()
In [2]:
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
In [3]:
seed = 444
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)

Dataset

Load

In [4]:
df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()
Out[4]:
avgAnnCount avgDeathsPerYear TARGET_deathRate incidenceRate medIncome popEst2015 povertyPercent studyPerCap binnedInc MedianAge ... PctPrivateCoverageAlone PctEmpPrivCoverage PctPublicCoverage PctPublicCoverageAlone PctWhite PctBlack PctAsian PctOtherRace PctMarriedHouseholds BirthRate
0 1397.0 469 164.9 489.8 61898 260131 11.2 499.748204 (61494.5, 125635] 39.3 ... NaN 41.6 32.9 14.0 81.780529 2.594728 4.821857 1.843479 52.856076 6.118831
1 173.0 70 161.3 411.6 48127 43269 18.6 23.111234 (48021.6, 51046.4] 33.0 ... 53.8 43.6 31.1 15.3 89.228509 0.969102 2.246233 3.741352 45.372500 4.333096
2 102.0 50 174.7 349.7 49348 21026 14.6 47.560164 (48021.6, 51046.4] 45.0 ... 43.5 34.9 42.1 21.1 90.922190 0.739673 0.465898 2.747358 54.444868 3.729488
3 427.0 202 194.8 430.4 44243 75882 17.1 342.637253 (42724.4, 45201] 42.8 ... 40.3 35.0 45.3 25.0 91.744686 0.782626 1.161359 1.362643 51.021514 4.603841
4 57.0 26 144.4 350.1 49955 10321 12.5 0.000000 (48021.6, 51046.4] 48.3 ... 43.9 35.1 44.0 22.7 94.104024 0.270192 0.665830 0.492135 54.027460 6.796657

5 rows × 34 columns

In [5]:
df.describe()
Out[5]:
avgAnnCount avgDeathsPerYear TARGET_deathRate incidenceRate medIncome popEst2015 povertyPercent studyPerCap MedianAge MedianAgeMale ... PctPrivateCoverageAlone PctEmpPrivCoverage PctPublicCoverage PctPublicCoverageAlone PctWhite PctBlack PctAsian PctOtherRace PctMarriedHouseholds BirthRate
count 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000 3.047000e+03 3047.000000 3047.000000 3047.000000 3047.000000 ... 2438.000000 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000 3047.000000
mean 606.338544 185.965868 178.664063 448.268586 47063.281917 1.026374e+05 16.878175 155.399415 45.272333 39.570725 ... 48.453774 41.196324 36.252642 19.240072 83.645286 9.107978 1.253965 1.983523 51.243872 5.640306
std 1416.356223 504.134286 27.751511 54.560733 12040.090836 3.290592e+05 6.409087 529.628366 45.304480 5.226017 ... 10.083006 9.447687 7.841741 6.113041 16.380025 14.534538 2.610276 3.517710 6.572814 1.985816
min 6.000000 3.000000 59.700000 201.300000 22640.000000 8.270000e+02 3.200000 0.000000 22.300000 22.400000 ... 15.700000 13.500000 11.200000 2.600000 10.199155 0.000000 0.000000 0.000000 22.992490 0.000000
25% 76.000000 28.000000 161.200000 420.300000 38882.500000 1.168400e+04 12.150000 0.000000 37.700000 36.350000 ... 41.000000 34.500000 30.900000 14.850000 77.296180 0.620675 0.254199 0.295172 47.763063 4.521419
50% 171.000000 61.000000 178.100000 453.549422 45207.000000 2.664300e+04 15.900000 0.000000 41.000000 39.600000 ... 48.700000 41.100000 36.300000 18.800000 90.059774 2.247576 0.549812 0.826185 51.669941 5.381478
75% 518.000000 149.000000 195.200000 480.850000 52492.000000 6.867100e+04 20.400000 83.650776 44.000000 42.500000 ... 55.600000 47.700000 41.550000 23.100000 95.451693 10.509732 1.221037 2.177960 55.395132 6.493677
max 38150.000000 14010.000000 362.800000 1206.900000 125635.000000 1.017029e+07 47.400000 9762.308998 624.000000 64.700000 ... 78.900000 70.700000 65.100000 46.600000 100.000000 85.947799 42.619425 41.930251 78.075397 21.326165

8 rows × 32 columns

Plot distribution of target variable

In [6]:
_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');

Define feature & target variables

In [7]:
target = 'TARGET_deathRate'
features = [
    col for col in df.columns
    if col not in [
        target, 
        'Geography',                # Label describing the county - each row has a different one
        'binnedInc',                # Redundant with median income?
        'PctSomeCol18_24',          # contains null values - ignoring for now
        'PctEmployed16_Over',       # contains null values - ignoring for now
        'PctPrivateCoverageAlone',  # contains null values - ignoring for now
    ]
]
print(len(features), 'features')
28 features
In [8]:
x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)
(3047, 28) (3047, 1)

Define model

In [9]:
class VariationalModel(nn.Module):
    def __init__(
        self, 
        n_inputs,
        encoding_size,
        n_hidden,
        x_scaler, 
        y_scaler,
        n_mixtures=5,
        n_enc_layers=1,
        n_dec_layers=1,
        jitter=1e-8,
    ):
        super().__init__()
        
        self.n_inputs = n_inputs
        self.encoding_size = encoding_size
        self.n_mixtures = n_mixtures
        self.jitter = jitter
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        
        # Prior
        self.prior = torch.distributions.Independent(
            torch.distributions.Normal(
                torch.zeros(n_mixtures, encoding_size), 
                torch.ones(n_mixtures, encoding_size),
            ), 
            1,
        )
        
        # Encoder
        enc_shared_layers = []
        for i in range(n_enc_layers):
            if i == 0:
                layer = nn.Linear(n_inputs, n_hidden)
            else:
                layer = nn.Linear(n_hidden, n_hidden)
    
            enc_shared_layers.append(layer)
            enc_shared_layers.append(nn.ReLU())
            
        self.enc_shared = nn.Sequential(*enc_shared_layers)
        self.mean_gmm_layer = nn.Linear(n_hidden, encoding_size * n_mixtures)
        self.std_gmm_layer = nn.Linear(n_hidden, encoding_size * n_mixtures)
        
        # Decoder
        dec_shared_layers = []
        for i in range(n_dec_layers):
            if i == 0:
                layer = nn.Linear(encoding_size, n_hidden)
            else:
                layer = nn.Linear(n_hidden, n_hidden)
                
            dec_shared_layers.append(layer)
            dec_shared_layers.append(nn.ReLU())
            
        self.dec_shared = nn.Sequential(*dec_shared_layers)
        self.mean_layer = nn.Linear(n_hidden, 1)
        self.std_layer = nn.Linear(n_hidden, 1)
        
    def encode(self, x):
        shared = self.x_scaler(x)
        shared = self.enc_shared(shared)
        
        shp = (-1, self.n_mixtures, self.encoding_size)
        mean = torch.reshape(self.mean_gmm_layer(shared), shp)
        std = torch.reshape(F.softplus(self.std_gmm_layer(shared)) + self.jitter, shp)
        
        return torch.distributions.Independent(torch.distributions.Normal(mean, std), 1)
    
    def decode(self, x_enc):
        shared = self.dec_shared(x_enc)
        mean = self.mean_layer(shared)
        std = F.softplus(self.std_layer(shared)) + self.jitter
        return torch.distributions.Normal(mean, std)
    
    def forward(self, x):
        gmm_dist = self.encode(x)
        encoding = torch.mean(gmm_dist.mean, dim=1)
        return self.decode(encoding)
In [10]:
def compute_loss(model, x, y):
    """
    ELBO loss - https://en.wikipedia.org/wiki/Evidence_lower_bound
    """
    y_scaled = model.y_scaler(y)
    
    gmm_dist = model.encode(x)
    encoding = torch.mean(gmm_dist.rsample(), dim=1)
    y_hat = model.decode(encoding)
    
    neg_log_likelihood = -y_hat.log_prob(y_scaled)
    
    kl_divergence = torch.distributions.kl.kl_divergence(gmm_dist, model.prior)
    
    return torch.mean(neg_log_likelihood + kl_divergence)


def compute_rmse(model, x_test, y_test):
    model.eval()
    y_hat = model(x_test)
    pred = model.y_scaler.inverse_transform(y_hat.mean)
    return torch.sqrt(torch.mean((pred - y_test)**2))


def train_one_step(model, optimizer, x_batch, y_batch):
    model.train()
    optimizer.zero_grad()
    loss = compute_loss(model, x_batch, y_batch)
    loss.backward()
    optimizer.step()
    return loss
In [11]:
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
    train_losses, val_losses = [], []
    for epoch in range(n_epochs):
        batch_indices = sample_batch_indices(x_train, y_train, batch_size)
        
        batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
        for batch_ix in batch_indices:
            b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])

            model.eval()
            b_val_loss = compute_loss(model, x_val, y_val)
            b_val_rmse = compute_rmse(model, x_val, y_val)

            batch_losses_t.append(b_train_loss.detach().numpy())
            batch_losses_v.append(b_val_loss.detach().numpy())
            batch_rmse_v.append(b_val_rmse.detach().numpy())
            
        if scheduler is not None:
            scheduler.step()
            
        train_loss = np.mean(batch_losses_t)
        val_loss = np.mean(batch_losses_v)
        val_rmse = np.mean(batch_rmse_v)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
        
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
    ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()
    
    return train_losses, val_losses


def sample_batch_indices(x, y, batch_size, rs=None):
    if rs is None:
        rs = np.random.RandomState()
    
    train_ix = np.arange(len(x))
    rs.shuffle(train_ix)
    
    n_batches = int(np.ceil(len(x) / batch_size))
    
    batch_indices = []
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_indices.append(
            train_ix[start:end].tolist()
        )

    return batch_indices
In [12]:
class StandardScaler(object):
    """
    Standardize data by removing the mean and scaling to unit variance.
    """
    def __init__(self):
        self.mean = None
        self.scale = None

    def fit(self, sample):
        self.mean = sample.mean(0, keepdim=True)
        self.scale = sample.std(0, unbiased=False, keepdim=True)
        return self

    def __call__(self, sample):
        return self.transform(sample)
    
    def transform(self, sample):
        return (sample - self.mean) / self.scale

    def inverse_transform(self, sample):
        return sample * self.scale + self.mean

Split between training and validation sets

In [13]:
def compute_train_test_split(x, y, test_size):
    x_train, x_test, y_train, y_test, train_ix, test_ix = train_test_split(
        x, 
        y, 
        list(range(len(x))),
        test_size=test_size,
    )
    return (
        torch.from_numpy(x_train),
        torch.from_numpy(x_test),
        torch.from_numpy(y_train),
        torch.from_numpy(y_test),
        train_ix,
        test_ix,
    )
In [14]:
x_train, x_val, y_train, y_val, train_ix, test_ix = compute_train_test_split(x, y, test_size=0.2)

print(x_train.shape, y_train.shape, len(train_ix))
print(x_val.shape, y_val.shape, len(test_ix))
torch.Size([2437, 28]) torch.Size([2437, 1]) 2437
torch.Size([610, 28]) torch.Size([610, 1]) 610

Train model

In [15]:
x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))
In [16]:
model = VariationalModel(
    n_inputs=x.shape[1],
    encoding_size=10,
    n_hidden=100,
    x_scaler=x_scaler, 
    y_scaler=y_scaler,
    n_mixtures=5,
    n_enc_layers=1,
    n_dec_layers=1,
)
model
Out[16]:
VariationalModel(
  (enc_shared): Sequential(
    (0): Linear(in_features=28, out_features=100, bias=True)
    (1): ReLU()
  )
  (mean_gmm_layer): Linear(in_features=100, out_features=50, bias=True)
  (std_gmm_layer): Linear(in_features=100, out_features=50, bias=True)
  (dec_shared): Sequential(
    (0): Linear(in_features=10, out_features=100, bias=True)
    (1): ReLU()
  )
  (mean_layer): Linear(in_features=100, out_features=1, bias=True)
  (std_layer): Linear(in_features=100, out_features=1, bias=True)
)
In [17]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{pytorch_total_params:,} trainable parameters')
print()
14,302 trainable parameters

In [18]:
learning_rate = 1e-4

optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=learning_rate, 
)
In [19]:
n_epochs = 1200
batch_size = 64
print_every = 50

train_losses, val_losses = train(
    model, 
    optimizer, 
    x_train, 
    x_val, 
    y_train, 
    y_val, 
    n_epochs=n_epochs, 
    batch_size=batch_size, 
    print_every=print_every,
)
Epoch 1 | Validation loss = 3.1302 | Validation RMSE = 28.0765
Epoch 50 | Validation loss = 1.4088 | Validation RMSE = 22.9438
Epoch 100 | Validation loss = 1.3531 | Validation RMSE = 20.8623
Epoch 150 | Validation loss = 1.3479 | Validation RMSE = 20.3298
Epoch 200 | Validation loss = 1.3469 | Validation RMSE = 20.1943
Epoch 250 | Validation loss = 1.3388 | Validation RMSE = 20.0699
Epoch 300 | Validation loss = 1.3446 | Validation RMSE = 19.6804
Epoch 350 | Validation loss = 1.3457 | Validation RMSE = 19.6464
Epoch 400 | Validation loss = 1.3456 | Validation RMSE = 19.5523
Epoch 450 | Validation loss = 1.3483 | Validation RMSE = 19.6363
Epoch 500 | Validation loss = 1.3462 | Validation RMSE = 19.3257
Epoch 550 | Validation loss = 1.3380 | Validation RMSE = 19.2351
Epoch 600 | Validation loss = 1.3586 | Validation RMSE = 19.1674
Epoch 650 | Validation loss = 1.3522 | Validation RMSE = 19.0560
Epoch 700 | Validation loss = 1.3601 | Validation RMSE = 19.0464
Epoch 750 | Validation loss = 1.3578 | Validation RMSE = 18.7227
Epoch 800 | Validation loss = 1.3654 | Validation RMSE = 18.6508
Epoch 850 | Validation loss = 1.3640 | Validation RMSE = 18.6167
Epoch 900 | Validation loss = 1.3600 | Validation RMSE = 18.5071
Epoch 950 | Validation loss = 1.3726 | Validation RMSE = 18.4926
Epoch 1000 | Validation loss = 1.3745 | Validation RMSE = 18.4887
Epoch 1050 | Validation loss = 1.3723 | Validation RMSE = 18.2981
Epoch 1100 | Validation loss = 1.3841 | Validation RMSE = 18.2612
Epoch 1150 | Validation loss = 1.3823 | Validation RMSE = 18.5471
Epoch 1200 | Validation loss = 1.3951 | Validation RMSE = 18.5240

Validation

In [20]:
y_dist = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)
In [21]:
val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
Validation RMSE = 18.29
In [22]:
val_r2 = r2_score(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')
Validation $R^2$ = 0.57
In [23]:
def plot_results(y_true, y_pred):
    _, ax = plt.subplots(1, 1, figsize=(7, 7))
    palette = sns.color_palette()
    
    min_value = min(np.amin(y_true), np.amin(y_pred))
    max_value = max(np.amax(y_true), np.amax(y_pred))
    y_mid = np.linspace(min_value, max_value)
    
    ax.plot(y_mid, y_mid, '--', color=palette[1])
    ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5);
    
    return ax
In [24]:
ax = plot_results(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
);

ax.text(270, 120, f'$R^2 = {val_r2:.2f}$')
ax.text(270, 100, f'$RMSE = {val_rmse:.2f}$')

ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');

Uncertainty

In [25]:
def compute_first_order_uncertainty(model, x):
    gmm_dist = model.encode(x)
    encoding = torch.mean(gmm_dist.mean, dim=1)
    out_dist = model.decode(encoding)
    
    inv_tr = model.y_scaler.inverse_transform
    y_hat = inv_tr(out_dist.mean)
    std = inv_tr(out_dist.mean + out_dist.stddev) - y_hat
    
    return std.detach().numpy()
In [26]:
first_order_std = compute_first_order_uncertainty(model, x_val)
first_order_std.shape
Out[26]:
(610, 1)
In [27]:
def compute_second_order_uncertainty(model, x, n_samples=1000, std=False):
    gmm_dist = model.encode(x)
    encoding_sample = torch.mean(gmm_dist.rsample((n_samples,)), dim=2)
    
    inv_tr = model.y_scaler.inverse_transform
    
    shp = encoding_sample.shape
    res = np.zeros((n_samples, shp[1], 1))
    stds = np.zeros((n_samples, shp[1], 1))
    for i in range(n_samples):
        d = model.decode(encoding_sample[i])
        mean = d.mean
        res[i] = inv_tr(mean).detach().numpy()
        stds[i] = (inv_tr(mean + d.stddev) - inv_tr(mean)).detach().numpy()
        
    second_order_std = np.std(res, axis=0)
    std_std = np.mean(stds, axis=0)
        
    if not std:    
        return second_order_std
    else:
        return second_order_std, std_std
In [28]:
second_order_std, std_std = compute_second_order_uncertainty(model, x_val, std=True)
second_order_std.shape
Out[28]:
(610, 1)
In [29]:
plt.hist(first_order_std.flatten(), bins=50);
In [30]:
plt.hist(second_order_std.flatten(), bins=50);
In [31]:
plt.hist(std_std.flatten(), bins=50);

Uncertainty on an input not yet seen & very different from the rest of the dataset

We'll consider an input made of the maximum of all features.

In [32]:
max_input_np = np.zeros((1, x_val.shape[1]))
for i in range(x_val.shape[1]):
    max_input_np[0, i] = torch.amax(x_val[:, i]) * 2

random_input = torch.from_numpy(max_input_np)
random_input.shape
Out[32]:
torch.Size([1, 28])
In [33]:
r_first_order_std = compute_first_order_uncertainty(model, random_input)
r_second_order_std = compute_second_order_uncertainty(model, random_input)

print(f'First order uncertainty: {float(r_first_order_std):.2f}')
print(f'Second order uncertainty: {float(r_second_order_std):.2f}')
First order uncertainty: 28.76
Second order uncertainty: 9.42

Quantiles

An instance's expectedness can be measured by its average feature quantile.

We can then estimate the relation between uncertainty and expectedness.

In [34]:
def set_feature_quantile(df, features):
    series = []
    quantiles = np.arange(0., 1., 0.01)
    for q in quantiles:
        series.append(df[features].quantile(q))
        
    def apply_quantile(row):
        v = np.zeros((len(features)))
        
        for i, f in enumerate(features):
            value = row[f]
            for j, q in enumerate(quantiles):
                if value <= series[j][f]:
                    v[i] = q
                    break
        
        return v.mean()
    
    df['quantile'] = df.apply(apply_quantile, axis=1)
In [35]:
set_feature_quantile(df, features)
In [36]:
df['quantile'].head()
Out[36]:
0    0.589286
1    0.413571
2    0.550000
3    0.568929
4    0.481786
Name: quantile, dtype: float64
In [37]:
df['quantile'].hist(bins=20);
In [49]:
def plot_expectedness_vs_uncertainty(df, test_ix, y_val, std):
    expectedness = np.abs(df['quantile'].iloc[test_ix].values - 0.5)
    stds = std.flatten() / y_val.detach().numpy().flatten()
    
    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.scatter(expectedness, stds, alpha=0.3)
    
    X = expectedness.reshape((-1, 1))
    y = stds
    linreg = LinearRegression().fit(X, y)
    print(f'R2: {linreg.score(X, y):.2f}')
    print(f'Correlation: {np.corrcoef(expectedness, stds)[0, 1]:.2f}')
    
    y_fit = linreg.coef_[0] * expectedness + linreg.intercept_
    
    ax.plot(expectedness, y_fit, '-', color=palette[1], alpha=0.5)
    
    return ax
In [51]:
ax = plot_expectedness_vs_uncertainty(df, test_ix, y_val, first_order_std)
R2: 0.01
Correlation: 0.09
In [52]:
ax = plot_expectedness_vs_uncertainty(df, test_ix, y_val, second_order_std)
R2: 0.01
Correlation: 0.12

Uncertainty vs absolute error

In [45]:
def plot_absolute_error_vs_uncertainty(y_val, y_hat, std):
    absolute_errors = (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten()
    stds = std.flatten() / y_val.detach().numpy().flatten()
    
    a = absolute_errors
    b = stds
    alpha = 0.3
    
    _, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.scatter(a, b, alpha=alpha)
    
    X = a.reshape((-1, 1))
    y = b
    linreg = LinearRegression().fit(X, y)
    print(f'R2: {linreg.score(X, y):.2f}')
    print(f'Correlation: {np.corrcoef(a, b)[0, 1]:.2f}')
    
    y_fit = linreg.coef_[0] * a + linreg.intercept_
    
    ax.plot(a, y_fit, '-', color=palette[1], alpha=0.5)
    
    return ax
In [46]:
plot_absolute_error_vs_uncertainty(y_val, y_hat, first_order_std);
R2: 0.23
Correlation: 0.48
In [47]:
plot_absolute_error_vs_uncertainty(y_val, y_hat, second_order_std);
R2: 0.24
Correlation: 0.49
In [48]:
plot_absolute_error_vs_uncertainty(y_val, y_hat, first_order_std + second_order_std);
R2: 0.24
Correlation: 0.49
In [ ]:
 
In [ ]: