#!/usr/bin/env python # coding: utf-8 # # Stochastic Variational Inference # # Implementation of Stochastic Variational Inference (SVI) using [PyTorch](https://pytorch.org/), for the purpose of uncertainty quantification. # # We'll consider the [OLS Regression Challenge](https://data.world/nrippner/ols-regression-challenge), which aims at predicting cancer mortality rates for US counties. # # ## Data Dictionary # # * **TARGET_deathRate**: Dependent variable. Mean per capita (100,000) cancer mortalities (a) # * **avgAnnCount**: Mean number of reported cases of cancer diagnosed annually (a) # * **avgDeathsPerYear**: Mean number of reported mortalities due to cancer (a) # * **incidenceRate**: Mean per capita (100,000) cancer diagoses (a) # * **medianIncome**: Median income per county (b) # * **popEst2015**: Population of county (b) # * **povertyPercent**: Percent of populace in poverty (b) # * **studyPerCap**: Per capita number of cancer-related clinical trials per county (a) # * **binnedInc**: Median income per capita binned by decile (b) # * **MedianAge**: Median age of county residents (b) # * **MedianAgeMale**: Median age of male county residents (b) # * **MedianAgeFemale**: Median age of female county residents (b) # * **Geography**: County name (b) # * **AvgHouseholdSize**: Mean household size of county (b) # * **PercentMarried**: Percent of county residents who are married (b) # * **PctNoHS18_24**: Percent of county residents ages 18-24 highest education attained: less than high school (b) # * **PctHS18_24**: Percent of county residents ages 18-24 highest education attained: high school diploma (b) # * **PctSomeCol18_24**: Percent of county residents ages 18-24 highest education attained: some college (b) # * **PctBachDeg18_24**: Percent of county residents ages 18-24 highest education attained: bachelor's degree (b) # * **PctHS25_Over**: Percent of county residents ages 25 and over highest education attained: high school diploma (b) # * **PctBachDeg25_Over**: Percent of county residents ages 25 and over highest education attained: bachelor's degree (b) # * **PctEmployed16_Over**: Percent of county residents ages 16 and over employed (b) # * **PctUnemployed16_Over**: Percent of county residents ages 16 and over unemployed (b) # * **PctPrivateCoverage**: Percent of county residents with private health coverage (b) # * **PctPrivateCoverageAlone**: Percent of county residents with private health coverage alone (no public assistance) (b) # * **PctEmpPrivCoverage**: Percent of county residents with employee-provided private health coverage (b) # * **PctPublicCoverage**: Percent of county residents with government-provided health coverage (b) # * **PctPubliceCoverageAlone**: Percent of county residents with government-provided health coverage alone (b) # * **PctWhite**: Percent of county residents who identify as White (b) # * **PctBlack**: Percent of county residents who identify as Black (b) # * **PctAsian**: Percent of county residents who identify as Asian (b) # * **PctOtherRace**: Percent of county residents who identify in a category which is not White, Black, or Asian (b) # * **PctMarriedHouseholds**: Percent of married households (b) # * **BirthRate**: Number of live births relative to number of women in county (b) # # Notes: # * (a): years 2010-2016 # * (b): 2013 Census Estimates # In[1]: import os from os.path import join import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import torch import torch.nn.functional as F from sklearn.metrics import r2_score from sklearn.model_selection import train_test_split from scipy.stats import binned_statistic cwd = os.getcwd() if cwd.endswith('notebook'): os.chdir('..') cwd = os.getcwd() # In[2]: sns.set(palette='colorblind', font_scale=1.3) palette = sns.color_palette() # In[3]: seed = 444 np.random.seed(seed); torch.manual_seed(seed); torch.set_default_dtype(torch.float64) # ## Dataset # # ### Load # In[4]: df = pd.read_csv(join(cwd, 'data/cancer_reg.csv')) df.head() # ### Describe # In[5]: df.describe() # ### Plot distribution of target variable # In[6]: _, ax = plt.subplots(1, 1, figsize=(10, 5)) df['TARGET_deathRate'].hist(bins=50, ax=ax); ax.set_title('Distribution of cancer death rate per 100,000 people'); ax.set_xlabel('Cancer death rate in county (per 100,000 people)'); ax.set_ylabel('Count'); # ### Define feature & target variables # In[7]: target = 'TARGET_deathRate' features = [ col for col in df.columns if col not in [ target, 'Geography', # Label describing the county - each row has a different one 'binnedInc', # Redundant with median income? 'PctSomeCol18_24', # contains null values - ignoring for now 'PctEmployed16_Over', # contains null values - ignoring for now 'PctPrivateCoverageAlone', # contains null values - ignoring for now ] ] print(len(features), 'features') # In[8]: x = df[features].values y = df[[target]].values print(x.shape, y.shape) # ## Define model # In[9]: class VariationalModel(torch.nn.Module): def __init__( self, n_inputs, n_hidden, priors, x_scaler, y_scaler, n_shared_layers=1, ): super().__init__() self.priors = priors self.x_scaler = x_scaler self.y_scaler = y_scaler self.jitter = 1e-6 # Shared layers self.shared_layers = [ torch.nn.Linear(n_inputs, n_hidden) ] if n_shared_layers > 1: for _ in range(n_shared_layers - 1): self.shared_layers.append( torch.nn.Linear(n_hidden, n_hidden) ) # Parameters for the mean self.mean_loc_hidden = torch.nn.Linear(n_hidden, n_hidden) self.mean_loc = torch.nn.Linear(n_hidden, 1) self.mean_scale_hidden = torch.nn.Linear(n_hidden, n_hidden) self.mean_scale = torch.nn.Linear(n_hidden, 1) # Parameters for the standard deviation self.std_loc_hidden = torch.nn.Linear(n_hidden, n_hidden) self.std_loc = torch.nn.Linear(n_hidden, 1) self.std_scale_hidden = torch.nn.Linear(n_hidden, n_hidden) self.std_scale = torch.nn.Linear(n_hidden, 1) def encoder(self, x): shared = self.x_scaler(x) for layer in self.shared_layers: shared = layer(shared) shared = F.relu(shared) #shared = self.dropout(shared) # Parametrization of the mean (normal) mean_loc = self.mean_loc_hidden(shared) mean_loc = F.relu(mean_loc) mean_loc = self.mean_loc(mean_loc) mean_scale = self.mean_scale_hidden(shared) mean_scale = F.relu(mean_scale) mean_scale = F.softplus(self.mean_scale(mean_scale)) + self.jitter # Parametrization of the standard deviation (log normal) std_loc = self.std_loc_hidden(shared) std_loc = F.relu(std_loc) std_loc = self.std_loc(std_loc) std_scale = self.std_scale_hidden(shared) std_scale = F.relu(std_scale) std_scale = F.softplus(self.std_scale(std_scale)) + self.jitter return ( torch.distributions.Normal(mean_loc, mean_scale), torch.distributions.LogNormal(std_loc, std_scale) ) def decoder(self, mean, std): return torch.distributions.Normal(mean, std) def forward(self, x, sample_shape=torch.Size([])): mean_dist, std_dist = self.encoder(x) mean = mean_dist.rsample(sample_shape) std = std_dist.rsample(sample_shape) + self.jitter y_hat = self.decoder(mean, std) kl_divergence = None if self.priors is not None: kl_divergence1 = torch.distributions.kl.kl_divergence(mean_dist, self.priors[0]) kl_divergence2 = torch.distributions.kl.kl_divergence(std_dist, self.priors[1]) kl_divergence = kl_divergence1 + kl_divergence2 return y_hat, kl_divergence def compute_loss(model, x, y, kl_reg=0.1): """ ELBO loss - https://en.wikipedia.org/wiki/Evidence_lower_bound """ y_scaled = model.y_scaler(y) y_hat, kl_divergence = model(x) neg_log_likelihood = -y_hat.log_prob(y_scaled) if kl_divergence is not None: neg_log_likelihood += kl_reg * kl_divergence return torch.mean(neg_log_likelihood) def compute_rmse(model, x_test, y_test): model.eval() y_hat, _ = model(x_test) pred = model.y_scaler.inverse_transform(y_hat.mean) return torch.sqrt(torch.mean((pred - y_test)**2)) def train_one_step(model, optimizer, x_batch, y_batch): model.train() optimizer.zero_grad() loss = compute_loss(model, x_batch, y_batch) loss.backward() optimizer.step() return loss # In[10]: def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10): train_losses, val_losses = [], [] for epoch in range(n_epochs): batch_indices = sample_batch_indices(x_train, y_train, batch_size) batch_losses_t, batch_losses_v, batch_rmse_v = [], [], [] for batch_ix in batch_indices: b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix]) model.eval() b_val_loss = compute_loss(model, x_val, y_val) b_val_rmse = compute_rmse(model, x_val, y_val) batch_losses_t.append(b_train_loss.detach().numpy()) batch_losses_v.append(b_val_loss.detach().numpy()) batch_rmse_v.append(b_val_rmse.detach().numpy()) if scheduler is not None: scheduler.step() train_loss = np.mean(batch_losses_t) val_loss = np.mean(batch_losses_v) val_rmse = np.mean(batch_rmse_v) train_losses.append(train_loss) val_losses.append(val_loss) if epoch == 0 or (epoch + 1) % print_every == 0: print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}') _, ax = plt.subplots(1, 1, figsize=(12, 6)) ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss') ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss') ax.set_xlabel('Epoch') ax.set_ylabel('Loss') ax.set_title('Training Overview') ax.legend() return train_losses, val_losses def sample_batch_indices(x, y, batch_size, rs=None): if rs is None: rs = np.random.RandomState() train_ix = np.arange(len(x)) rs.shuffle(train_ix) n_batches = int(np.ceil(len(x) / batch_size)) batch_indices = [] for i in range(n_batches): start = i * batch_size end = start + batch_size batch_indices.append( train_ix[start:end].tolist() ) return batch_indices # In[11]: class StandardScaler(object): """ Standardize data by removing the mean and scaling to unit variance. """ def __init__(self): self.mean = None self.scale = None def fit(self, sample): self.mean = sample.mean(0, keepdim=True) self.scale = sample.std(0, unbiased=False, keepdim=True) return self def __call__(self, sample): return self.transform(sample) def transform(self, sample): return (sample - self.mean) / self.scale def inverse_transform(self, sample): return sample * self.scale + self.mean # ## Split between training and validation sets # In[12]: def compute_train_test_split(x, y, test_size): x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size) return ( torch.from_numpy(x_train), torch.from_numpy(x_test), torch.from_numpy(y_train), torch.from_numpy(y_test), ) # In[13]: x_train, x_val, y_train, y_val = compute_train_test_split(x, y, test_size=0.2) print(x_train.shape, y_train.shape) print(x_val.shape, y_val.shape) # ## Train model # In[14]: x_scaler = StandardScaler().fit(torch.from_numpy(x)) y_scaler = StandardScaler().fit(torch.from_numpy(y)) # In[15]: def lr_steep_schedule(epoch): threshold = 5 if epoch < threshold: return 1.0 multiplier = 2 / 10 step_change = (epoch - threshold) // 300 return multiplier * 0.5 ** (step_change) def lr_simple_schedule(epoch): return 0.5 ** (epoch // 200) # In[16]: def make_priors(y_train, y_scaler): std = torch.Tensor((0.5,)) scale = torch.Tensor((0.2,)) std_loc = torch.log(std**2 / torch.sqrt(std**2 + scale**2)) std_scale = torch.sqrt(torch.log(1 + scale**2 / std**2)) return [ torch.distributions.Normal(0, 1), torch.distributions.LogNormal(std_loc, std_scale), ] # In[17]: get_ipython().run_cell_magic('time', '', "\nlearning_rate = 1e-3\nmomentum = 0.9\nweight_decay = 1e-4\n\nn_epochs = 400\nbatch_size = 64\nprint_every = 50\n\nn_hidden = 100\nn_shared_layers = 1\n\npriors = make_priors(y_train, y_scaler)\n\nmodel = VariationalModel(\n n_inputs=x.shape[1],\n n_hidden=n_hidden,\n priors=priors,\n n_shared_layers=n_shared_layers,\n x_scaler=x_scaler, \n y_scaler=y_scaler,\n)\n\npytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\nprint(f'{pytorch_total_params:,} trainable parameters')\nprint()\n\noptimizer = torch.optim.SGD(\n model.parameters(), \n lr=learning_rate, \n momentum=momentum, \n nesterov=True,\n weight_decay=weight_decay,\n)\nscheduler = torch.optim.lr_scheduler.LambdaLR(\n optimizer, \n lr_lambda=lr_steep_schedule,\n)\n\ntrain_losses, val_losses = train(\n model, \n optimizer, \n x_train, \n x_val, \n y_train, \n y_val, \n n_epochs=n_epochs, \n batch_size=batch_size, \n scheduler=scheduler, \n print_every=print_every,\n)\n") # ### Validation # In[18]: y_dist, _ = model(x_val) y_hat = model.y_scaler.inverse_transform(y_dist.mean) # In[44]: val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy()) print(f'Validation RMSE = {val_rmse:.2f}') # In[20]: val_r2 = r2_score( y_val.detach().numpy().flatten(), y_hat.detach().numpy().flatten(), ) print(f'Validation $R^2$ = {val_r2:.2f}') # In[21]: def plot_results(y_true, y_pred): _, ax = plt.subplots(1, 1, figsize=(7, 7)) palette = sns.color_palette() min_value = min(np.amin(y_true), np.amin(y_pred)) max_value = max(np.amax(y_true), np.amax(y_pred)) y_mid = np.linspace(min_value, max_value) ax.plot(y_mid, y_mid, '--', color=palette[1]) ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5); return ax # In[39]: ax = plot_results( y_val.detach().numpy().flatten(), y_hat.detach().numpy().flatten(), ); ax.text(270, 120, f'$R^2 = {val_r2:.2f}$') ax.text(270, 100, f'$RMSE = {val_rmse:.2f}$') ax.set_xlabel('Actuals'); ax.set_ylabel('Predictions'); ax.set_title('Regression results on validation set'); # ## Uncertainty # In[23]: def make_predictions(model, x): mean_dist, std_dist = model.encoder(x) inv_tr = model.y_scaler.inverse_transform y_hat = inv_tr(mean_dist.mean) # Recover standard deviation's original scale std_from_mean_dist = inv_tr(mean_dist.mean + mean_dist.stddev) - y_hat std_from_std_dist = inv_tr(mean_dist.mean + std_dist.mean + std_dist.stddev) - y_hat return y_hat, std_from_mean_dist, std_from_std_dist # In[24]: y_hat, std_from_mean_dist , std_from_std_dist = make_predictions(model, x_val) std = std_from_mean_dist + std_from_std_dist # In[25]: plt.hist(y_hat.detach().numpy().flatten(), bins=50); # In[26]: plt.hist(std_from_mean_dist.detach().numpy().flatten(), bins=50); # In[27]: plt.hist(std_from_std_dist.detach().numpy().flatten(), bins=50); # ### Plot absolute error against uncertainty # In[28]: absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten() stds = std.detach().numpy().flatten() df = pd.DataFrame.from_dict({ 'error': absolute_errors, 'uncertainty': stds }) df.corr() # In[29]: def plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned): absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten() stds = std.detach().numpy().flatten() if binned: ret = binned_statistic(absolute_errors, stds, bins=100) a = ret.bin_edges[1:] b = ret.statistic alpha = 1.0 else: a = absolute_errors b = stds alpha = 0.3 _, ax = plt.subplots(1, 1, figsize=(8, 8)) ax.scatter(a, b, alpha=alpha) max_val = max(ax.get_xlim()[-1], ax.get_ylim()[-1]) x = np.linspace(0, max_val) ax.plot(x, x, '--', color='#ccc', alpha=0.8) return ax # In[30]: ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned=True) ax.set_title('Absolute error vs uncertainty (binned & averaged)'); ax.set_xlabel('Absolute error'); ax.set_ylabel('Standard deviations'); # In[31]: ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned=False) ax.set_title('Absolute error vs uncertainty'); ax.set_xlabel('Absolute error'); ax.set_ylabel('Standard deviations'); # In[32]: def plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned): absolute_errors = (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten() stds = (std / y_val).detach().numpy().flatten() if binned: ret = binned_statistic(absolute_errors, stds, bins=100) a = ret.bin_edges[1:] b = ret.statistic alpha = 1.0 else: a = absolute_errors b = stds alpha = 0.3 _, ax = plt.subplots(1, 1, figsize=(8, 8)) ax.scatter(a, b, alpha=alpha) max_val = max(ax.get_xlim()[-1], ax.get_ylim()[-1]) x = np.linspace(0, max_val) ax.plot(x, x, '--', color='#ccc', alpha=0.8) return ax # In[33]: ax = plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned=True) ax.set_title('Absolute error vs uncertainty (normalized)'); ax.set_xlabel('Absolute error / actual'); ax.set_ylabel('Standard deviations / actual'); # In[34]: ax = plot_absolute_error_vs_uncertainty_normalized(y_val, y_hat, std, binned=False) ax.set_title('Absolute error vs uncertainty'); ax.set_xlabel('Absolute error / actual'); ax.set_ylabel('Standard deviations / actual'); # ### Uncertainty on an input not yet seen & very different from the rest of the dataset # # We'll consider an input composed of the maximum of all features. # In[35]: random_input_np = np.zeros((1, x_val.shape[1])) for i in range(x_val.shape[1]): fn = torch.amax random_input_np[0, i] = fn(x_val[:, i]) * 2 random_input = torch.from_numpy(random_input_np) random_input.shape # In[36]: _, std_epistemic, std_aleatoric = make_predictions(model, random_input) uncertainty = float((std_epistemic + std_aleatoric).detach().numpy()) print(f'Uncertainty on made up input: {uncertainty:.2f}') # Uncertainty is high, which is what we want! # In[37]: _, std_epistemic, std_aleatoric = make_predictions(model, random_input) print(f'Aleatoric uncertainty: {float(std_aleatoric.detach().numpy()):.2f}') print(f'Epistemic uncertainty: {float(std_epistemic.detach().numpy()):.2f}') # In[ ]: