#!/usr/bin/env python
# coding: utf-8

# # Parametrising a distribution with a neural network
# 
# Parametrization of a [PyTorch](https://pytorch.org/) distribution with a neural network, for the purpose of uncertainty quantification. A minimal sketch of the idea follows the data dictionary below.
# 
# We'll consider the [OLS Regression Challenge](https://data.world/nrippner/ols-regression-challenge), which aims at predicting cancer mortality rates for US counties.
# 
# ## Data Dictionary
# 
# * **TARGET_deathRate**: Dependent variable. Mean per capita (100,000) cancer mortalities (a)
# * **avgAnnCount**: Mean number of reported cases of cancer diagnosed annually (a)
# * **avgDeathsPerYear**: Mean number of reported mortalities due to cancer (a)
# * **incidenceRate**: Mean per capita (100,000) cancer diagnoses (a)
# * **medianIncome**: Median income per county (b)
# * **popEst2015**: Population of county (b)
# * **povertyPercent**: Percent of populace in poverty (b)
# * **studyPerCap**: Per capita number of cancer-related clinical trials per county (a)
# * **binnedInc**: Median income per capita binned by decile (b)
# * **MedianAge**: Median age of county residents (b)
# * **MedianAgeMale**: Median age of male county residents (b)
# * **MedianAgeFemale**: Median age of female county residents (b)
# * **Geography**: County name (b)
# * **AvgHouseholdSize**: Mean household size of county (b)
# * **PercentMarried**: Percent of county residents who are married (b)
# * **PctNoHS18_24**: Percent of county residents ages 18-24, highest education attained: less than high school (b)
# * **PctHS18_24**: Percent of county residents ages 18-24, highest education attained: high school diploma (b)
# * **PctSomeCol18_24**: Percent of county residents ages 18-24, highest education attained: some college (b)
# * **PctBachDeg18_24**: Percent of county residents ages 18-24, highest education attained: bachelor's degree (b)
# * **PctHS25_Over**: Percent of county residents ages 25 and over, highest education attained: high school diploma (b)
# * **PctBachDeg25_Over**: Percent of county residents ages 25 and over, highest education attained: bachelor's degree (b)
# * **PctEmployed16_Over**: Percent of county residents ages 16 and over who are employed (b)
# * **PctUnemployed16_Over**: Percent of county residents ages 16 and over who are unemployed (b)
# * **PctPrivateCoverage**: Percent of county residents with private health coverage (b)
# * **PctPrivateCoverageAlone**: Percent of county residents with private health coverage alone (no public assistance) (b)
# * **PctEmpPrivCoverage**: Percent of county residents with employee-provided private health coverage (b)
# * **PctPublicCoverage**: Percent of county residents with government-provided health coverage (b)
# * **PctPublicCoverageAlone**: Percent of county residents with government-provided health coverage alone (b)
# * **PctWhite**: Percent of county residents who identify as White (b)
# * **PctBlack**: Percent of county residents who identify as Black (b)
# * **PctAsian**: Percent of county residents who identify as Asian (b)
# * **PctOtherRace**: Percent of county residents who identify in a category which is not White, Black, or Asian (b)
# * **PctMarriedHouseholds**: Percent of married households (b)
# * **BirthRate**: Number of live births relative to number of women in county (b)
# 
# Notes:
# * (a): years 2010-2016
# * (b): 2013 Census Estimates
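# Here is a minimal sketch of the idea on made-up toy tensors (an illustrative aside, not part of the pipeline below): a network outputs a mean and a positive standard deviation, those two tensors parametrise a `torch.distributions.Normal`, and training minimises the negative log-likelihood of the observed targets under that distribution.

# In[ ]:

# Minimal sketch with toy data: parametrise a Normal with a network and fit it by NLL.
import torch
import torch.nn.functional as F

toy_net = torch.nn.Linear(3, 2)   # toy network: 3 features -> (mean, raw scale)
x_toy = torch.randn(8, 3)         # 8 made-up instances
y_toy = torch.randn(8, 1)         # made-up targets

out = toy_net(x_toy)
mean, raw_std = out[:, :1], out[:, 1:]
std = F.softplus(raw_std) + 1e-6  # softplus keeps the scale strictly positive
dist = torch.distributions.Normal(mean, std)

loss = -dist.log_prob(y_toy).mean()  # negative log-likelihood to minimise
loss.backward()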
# In[94]:

import os
from os.path import join

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from scipy.stats import binned_statistic
import statsmodels.api as sm

cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()


# In[267]:

sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()


# In[3]:

seed = 456
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)


# ## Dataset

# In[4]:

df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()


# In[5]:

df.describe()


# ### Plot distribution of target variable

# In[6]:

_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');


# ### Define feature & target variables

# In[7]:

target = 'TARGET_deathRate'
features = [
    col for col in df.columns
    if col not in [
        target,
        'Geography',                # Label describing the county - each row has a different one
        'binnedInc',                # Redundant with median income?
        'PctSomeCol18_24',          # contains null values - ignoring for now
        'PctEmployed16_Over',       # contains null values - ignoring for now
        'PctPrivateCoverageAlone',  # contains null values - ignoring for now
    ]
]
print(len(features), 'features')


# In[8]:

x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)


# ## Define model

# In[9]:

class DeepNormalModel(torch.nn.Module):

    def __init__(
        self,
        n_inputs,
        n_hidden,
        x_scaler,
        y_scaler,
    ):
        super().__init__()

        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        self.jitter = 1e-6

        self.shared = torch.nn.Linear(n_inputs, n_hidden)

        self.mean_hidden = torch.nn.Linear(n_hidden, n_hidden)
        self.mean_linear = torch.nn.Linear(n_hidden, 1)

        self.std_hidden = torch.nn.Linear(n_hidden, n_hidden)
        self.std_linear = torch.nn.Linear(n_hidden, 1)

        self.dropout = torch.nn.Dropout()

    def forward(self, x):
        # Normalization
        shared = self.x_scaler(x)

        # Shared layer
        shared = self.shared(shared)
        shared = F.relu(shared)
        shared = self.dropout(shared)

        # Parametrization of the mean
        mean_hidden = self.mean_hidden(shared)
        mean_hidden = F.relu(mean_hidden)
        mean_hidden = self.dropout(mean_hidden)
        mean = self.mean_linear(mean_hidden)

        # Parametrization of the standard deviation
        std_hidden = self.std_hidden(shared)
        std_hidden = F.relu(std_hidden)
        std_hidden = self.dropout(std_hidden)
        # Softplus keeps the standard deviation positive; the jitter avoids numerical issues at zero
        std = F.softplus(self.std_linear(std_hidden)) + self.jitter

        return torch.distributions.Normal(mean, std)


# In[10]:

def compute_loss(model, x, y):
    y_scaled = model.y_scaler(y)
    y_hat = model(x)

    neg_log_likelihood = -y_hat.log_prob(y_scaled)
    return torch.mean(neg_log_likelihood)


def compute_rmse(model, x_test, y_test):
    model.eval()
    y_hat = model(x_test)
    pred = model.y_scaler.inverse_transform(y_hat.mean)
    return torch.sqrt(torch.mean((pred - y_test)**2))


def train_one_step(model, optimizer, x_batch, y_batch):
    model.train()
    optimizer.zero_grad()

    loss = compute_loss(model, x_batch, y_batch)

    loss.backward()
    optimizer.step()

    return loss
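# For a single instance, the loss above is the negative log-likelihood of a Gaussian:
# 
# $$-\log p(y \mid \mu, \sigma) = \frac{1}{2}\log(2\pi\sigma^2) + \frac{(y - \mu)^2}{2\sigma^2}$$
# 
# Minimising it pulls $\mu$ towards the target while letting the network inflate $\sigma$ on instances it cannot fit well, which is what makes the predicted standard deviation usable as an uncertainty estimate. The quick check below is an illustrative aside, not part of the original notebook.

# In[ ]:

# Sanity check: PyTorch's Normal.log_prob agrees with the closed-form Gaussian NLL.
_mu, _sigma, _y = torch.tensor(0.3), torch.tensor(1.2), torch.tensor(-0.5)
_nll_closed_form = 0.5 * torch.log(2 * torch.tensor(np.pi) * _sigma**2) + (_y - _mu)**2 / (2 * _sigma**2)
print(torch.isclose(-torch.distributions.Normal(_mu, _sigma).log_prob(_y), _nll_closed_form))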
# In[11]:

def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
    train_losses, val_losses = [], []

    for epoch in range(n_epochs):
        batch_indices = sample_batch_indices(x_train, y_train, batch_size)

        batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []

        for batch_ix in batch_indices:
            b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])

            model.eval()
            b_val_loss = compute_loss(model, x_val, y_val)
            b_val_rmse = compute_rmse(model, x_val, y_val)

            batch_losses_t.append(b_train_loss.detach().numpy())
            batch_losses_v.append(b_val_loss.detach().numpy())
            batch_rmse_v.append(b_val_rmse.detach().numpy())

        if scheduler is not None:
            scheduler.step()

        train_loss = np.mean(batch_losses_t)
        val_loss = np.mean(batch_losses_v)
        val_rmse = np.mean(batch_rmse_v)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if epoch == 0 or (epoch + 1) % print_every == 0:
            print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')

    _, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
    ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Overview')
    ax.legend()

    return train_losses, val_losses


def sample_batch_indices(x, y, batch_size, rs=None):
    if rs is None:
        rs = np.random.RandomState()

    train_ix = np.arange(len(x))
    rs.shuffle(train_ix)

    n_batches = int(np.ceil(len(x) / batch_size))

    batch_indices = []
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_indices.append(
            train_ix[start:end].tolist()
        )

    return batch_indices


# In[12]:

class StandardScaler(object):
    """
    Standardize data by removing the mean and scaling to unit variance.
    """
    def __init__(self):
        self.mean = None
        self.scale = None

    def fit(self, sample):
        self.mean = sample.mean(0, keepdim=True)
        self.scale = sample.std(0, unbiased=False, keepdim=True)
        return self

    def __call__(self, sample):
        return self.transform(sample)

    def transform(self, sample):
        return (sample - self.mean) / self.scale

    def inverse_transform(self, sample):
        return sample * self.scale + self.mean


# ## Split between training and validation sets

# In[13]:

def compute_train_test_split(x, y, test_size):
    x_train, x_test, y_train, y_test, train_ix, test_ix = train_test_split(
        x,
        y,
        list(range(len(x))),
        test_size=test_size,
    )
    return (
        torch.from_numpy(x_train),
        torch.from_numpy(x_test),
        torch.from_numpy(y_train),
        torch.from_numpy(y_test),
        train_ix,
        test_ix,
    )


# In[14]:

x_train, x_val, y_train, y_val, train_ix, test_ix = compute_train_test_split(x, y, test_size=0.2)

print(x_train.shape, y_train.shape, len(train_ix))
print(x_val.shape, y_val.shape, len(test_ix))


# ## Train model

# In[15]:

x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))


# In[16]:

get_ipython().run_cell_magic('time', '', "\nlearning_rate = 1e-3\nmomentum = 0.9\nweight_decay = 1e-5\n\nn_epochs = 300\nbatch_size = 64\nprint_every = 50\n\nn_hidden = 100\n\nmodel = DeepNormalModel(\n    n_inputs=x.shape[1],\n    n_hidden=n_hidden,\n    x_scaler=x_scaler,\n    y_scaler=y_scaler,\n)\n\npytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\nprint(f'{pytorch_total_params:,} trainable parameters')\nprint()\n\noptimizer = torch.optim.SGD(\n    model.parameters(),\n    lr=learning_rate,\n    momentum=momentum,\n    nesterov=True,\n    weight_decay=weight_decay,\n)\nscheduler = None\n\ntrain_losses, val_losses = train(\n    model,\n    optimizer,\n    x_train,\n    x_val,\n    y_train,\n    y_val,\n    n_epochs=n_epochs,\n    batch_size=batch_size,\n    scheduler=scheduler,\n    print_every=print_every,\n)\n")


# ### Validation

# In[17]:

y_dist = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)


# In[18]:

val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
# In[19]:

val_r2 = r2_score(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')


# In[20]:

def plot_results(y_true, y_pred):
    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    palette = sns.color_palette()

    min_value = min(np.amin(y_true), np.amin(y_pred))
    max_value = max(np.amax(y_true), np.amax(y_pred))
    y_mid = np.linspace(min_value, max_value)

    ax.plot(y_mid, y_mid, '--', color=palette[1])
    ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5);

    return f, ax


# In[21]:

f, ax = plot_results(
    y_val.detach().numpy().flatten(),
    y_hat.detach().numpy().flatten(),
);
ax.text(225, 95, f'$R^2 = {val_r2:.2f}$')
ax.text(225, 80, f'$RMSE = {val_rmse:.2f}$')
ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');


# ## Uncertainty

# In[22]:

def make_predictions(model, x):
    dist = model(x)
    inv_tr = model.y_scaler.inverse_transform

    y_hat = inv_tr(dist.mean)
    # Recover the standard deviation's original scale: the scaler is affine, so
    # inverse-transforming (mean + stddev) and subtracting the inverse-transformed
    # mean leaves the stddev multiplied by the scaler's scale factor.
    std = inv_tr(dist.mean + dist.stddev) - y_hat

    return y_hat, std


# In[23]:

y_hat, std = make_predictions(model, x_val)


# In[337]:

f, axes = plt.subplots(1, 2, figsize=(18, 6))
ax1, ax2 = axes

ax1.hist(y_hat.detach().numpy().flatten(), bins=50);
ax2.hist(std.detach().numpy().flatten(), bins=50, color=palette[1]);

ax1.set_xlabel('Cancer mortality rate (mean)');
ax1.set_ylabel('Instance count');
ax1.set_title(r'Distribution of mean $\mu$ on validation set');
ax2.set_xlabel('Cancer mortality rate (stddev)');
ax2.set_ylabel('Instance count');
ax2.set_title(r'Distribution of stddev $\sigma$ on validation set');


# In[240]:

low_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if float(v) < 140][0]
average_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if 170 < float(v) < 190][0]
high_ix = [i for i, v in enumerate(y_hat.detach().numpy().flatten()) if float(v) > 220][0]


# In[265]:

def plot_normal_distributions(y_hat, std, indices):
    f, ax = plt.subplots(1, 1, figsize=(12, 6))

    normal_fn = lambda z, m, s: (1 / (s * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((z - m) / s)**2)
    x = np.linspace(80, 300, 300)

    for ix in indices:
        mu = y_hat.detach().numpy().flatten()[ix]
        sigma = std.detach().numpy().flatten()[ix]
        ax.plot(x, normal_fn(x, mu, sigma), label=rf'$\mu = {mu:.0f}$, $\sigma = {sigma:.1f}$')

    ax.set_xlabel('Cancer mortality rate ($y$)');
    ax.set_ylabel(r'$P(y \mid \mu, \sigma)$');
    ax.set_title('PDF of three instances of the validation set');
    ax.legend(loc='upper left');

    return f, ax


# In[338]:

f, ax = plot_normal_distributions(y_hat, std, [low_ix, average_ix, high_ix])


# ### Plot absolute error against uncertainty

# In[291]:

def plot_absolute_error_vs_uncertainty(y_val, y_hat, std):
    # Absolute error relative to the actual value, as a percentage
    absolute_errors = 100 * (torch.abs(y_hat - y_val) / y_val).detach().numpy().flatten()
    stds = std.detach().numpy().flatten()

    a = absolute_errors
    b = stds
    alpha = 0.5

    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    ax.scatter(a, b, alpha=alpha)

    X = a.reshape((-1, 1))
    y = b
    linreg = LinearRegression().fit(X, y)
    r2 = linreg.score(X, y)
    corr = np.corrcoef(a, b)[0, 1]
    print(f'R2: {r2:.2f}')
    print(f'Correlation: {corr:.2f}')

    y_fit = linreg.coef_[0] * a + linreg.intercept_
    ax.plot(a, y_fit, '-', color=palette[1], alpha=0.5)
    ax.text(56, 7.5, f'Corr = {corr:.2f}')

    return f, ax


# In[292]:

f, ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std)
ax.set_title('Error vs uncertainty');
ax.set_xlabel('Absolute error / actual (%)');
ax.set_ylabel('Standard deviation');
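# ### Prediction intervals (illustrative aside)
# 
# Since the model predicts a full Normal distribution per county, its uncertainty can also be reported as a prediction interval. The cell below is a sketch added for illustration (the 95% level and the coverage check are not part of the original analysis): it turns the predicted $\mu$ and $\sigma$ on the validation set into a central 95% interval and checks how often the actual death rate falls inside it.

# In[ ]:

# Central 95% prediction intervals from the predicted Normal parameters
# (`y_hat` and `std` were computed above on the validation set).
z_95 = 1.96  # ~97.5th percentile of the standard Normal

lower = y_hat - z_95 * std
upper = y_hat + z_95 * std

# Fraction of validation counties whose actual rate falls inside the interval
inside = (y_val >= lower) & (y_val <= upper)
coverage = float(inside.double().mean())
print(f'Empirical coverage of the 95% interval: {coverage:.1%}')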
# ### Uncertainty on an input not yet seen & very different from the rest of the dataset
# 
# We'll consider an input where every feature is set to twice its maximum value in the validation set.

# In[28]:

max_input_np = np.zeros((1, x_val.shape[1]))
for i in range(x_val.shape[1]):
    max_input_np[0, i] = torch.amax(x_val[:, i]) * 2
max_input = torch.from_numpy(max_input_np)
max_input.shape


# In[167]:

r_y_hat, r_first_order_std = make_predictions(model, max_input)
print(f'First order uncertainty: {float((r_first_order_std / r_y_hat).detach().numpy()):.2f}')


# ## Unexpectedness
# 
# An instance's unexpectedness can be measured by how far its average feature quantile deviates from the median (0.5).
# 
# We can then estimate the relation between uncertainty and unexpectedness.

# In[30]:

def set_feature_quantile(df, features):
    series = []
    quantiles = np.arange(0., 1., 0.01)
    for q in quantiles:
        series.append(df[features].quantile(q))

    def apply_quantile(row):
        v = np.zeros((len(features)))
        for i, f in enumerate(features):
            value = row[f]
            for j, q in enumerate(quantiles):
                if value <= series[j][f]:
                    v[i] = q
                    break
            else:
                # Value is above the highest computed quantile
                v[i] = 1.0
        return v.mean()

    df['quantile'] = df.apply(apply_quantile, axis=1)


# In[31]:

set_feature_quantile(df, features)


# In[32]:

df['quantile'].hist(bins=20);


# In[331]:

def plot_unexpectedness_vs_uncertainty(df, test_ix, y_val, std):
    quantiles = df['quantile'].iloc[test_ix].values
    unexpectedness = 100 * np.abs(quantiles - 0.5)
    stds = std.detach().numpy().flatten()

    f, ax = plt.subplots(1, 1, figsize=(7, 7))
    ax.scatter(unexpectedness, stds, alpha=0.5)

    X = sm.add_constant(unexpectedness.reshape((-1, 1)))
    y = stds.reshape((-1, 1))
    linreg = sm.OLS(y, X).fit()
    y_fit = linreg.predict(X).flatten()

    corr = np.corrcoef(unexpectedness, stds)[0, 1]
    print(f'Correlation: {corr:.2f}')
    print(linreg.summary())

    ax.plot(unexpectedness, y_fit, '-', color=palette[1], alpha=0.5)
    ax.text(27, 7.5, f'Corr = {corr:.2f}')

    return f, ax


# In[332]:

f, ax = plot_unexpectedness_vs_uncertainty(df, test_ix, y_val, std)
ax.set_xlabel('Unexpectedness (%)')
ax.set_ylabel('Standard deviation')
ax.set_title('Unexpectedness vs uncertainty');
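# As a rough cross-check of the quantile computation above (an illustrative aside, not part of the original analysis), pandas' percentile ranks give a vectorised approximation of the same mean-feature-quantile idea.

# In[ ]:

# Vectorised approximation: rank each feature as a percentile (0-1) and average
# across features per county. The binning differs slightly from set_feature_quantile,
# so the values will not match exactly.
approx_quantile = df[features].rank(pct=True).mean(axis=1)

print(approx_quantile.describe())
print(f"Correlation with df['quantile']: {np.corrcoef(approx_quantile, df['quantile'])[0, 1]:.3f}")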