#!/usr/bin/env python # coding: utf-8 # # Bayesian Regression Using NumPyro # # In this tutorial, we will explore how to do Bayesian regression in NumPyro, using a simple example adapted from Statistical Rethinking [[1](#References)]. In particular, we would like to explore the following: # # - Write a simple model using the `sample` NumPyro primitive. # - Run inference using MCMC in NumPyro, in particular, using the No U-Turn Sampler (NUTS) to get a posterior distribution over our regression parameters of interest. # - Learn about inference utilities such as `Predictive` and `log_likelihood`. # - Learn how we can use effect-handlers in NumPyro to generate execution traces from the model, condition on sample statements, seed models with RNG seeds, etc., and use this to implement various utilities that will be useful for MCMC, e.g. computing model log likelihood, generating an empirical distribution over the posterior predictive, etc. # # ## Tutorial Outline: # # 1. [Dataset](#Dataset) # 2. [Regression Model to Predict Divorce Rate](#Regression-Model-to-Predict-Divorce-Rate) # - [Model-1: Predictor-Marriage Rate](#Model-1:-Predictor---Marriage-Rate) # - [Posterior Distribution over the Regression Parameters](#Posterior-Distribution-over-the-Regression-Parameters) # - [Prior Predictive Distribution](#Prior-Predictive-Distribution) # - [Posterior Predictive Distribution](#Posterior-Predictive-Distribution) # - [Predictive Utility With Effect Handlers](#Predictive-Utility-With-Effect-Handlers) # - [Posterior Predictive Density](#Posterior-Predictive-Density) # - [Model-2: Predictor-Median Age of Marriage](#Model-2:-Predictor---Median-Age-of-Marriage) # - [Model-3: Predictor-Marriage Rate and Median Age of Marriage](#Model-3:-Predictor---Marriage-Rate-and-Median-Age-of-Marriage) # - [Divorce Rate Residuals by State](#Divorce-Rate-Residuals-by-State) # 3. [Regression Model with Measurement Error](#Regression-Model-with-Measurement-Error) # - [Effect of Incorporating Measurement Noise on Residuals](#Effect-of-Incorporating-Measurement-Noise-on-Residuals) # 4. [References](#References) # In[1]: get_ipython().system('pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro') # In[2]: import os from IPython.display import set_matplotlib_formats import matplotlib.pyplot as plt import pandas as pd import seaborn as sns from jax import random, vmap import jax.numpy as jnp from jax.scipy.special import logsumexp import numpyro from numpyro import handlers from numpyro.diagnostics import hpdi import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS plt.style.use("bmh") if "NUMPYRO_SPHINXBUILD" in os.environ: set_matplotlib_formats("svg") assert numpyro.__version__.startswith("0.14.0") # ## Dataset # # For this example, we will use the `WaffleDivorce` dataset from Chapter 05, Statistical Rethinking [[1](#References)]. The dataset contains divorce rates for each of the 50 states in the USA, along with predictors such as population, median age of marriage, whether it is a Southern state and, curiously, the number of Waffle Houses. # In[3]: DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv" dset = pd.read_csv(DATASET_URL, sep=";") dset # Let us plot the pair-wise relationships amongst the main variables in the dataset, using `seaborn.pairplot`.
# In[4]: vars = [ "Population", "MedianAgeMarriage", "Marriage", "WaffleHouses", "South", "Divorce", ] sns.pairplot(dset, x_vars=vars, y_vars=vars, palette="husl"); # From the plots above, we can clearly observe that there is a relationship between divorce rates and marriage rates in a state (as might be expected), and also between divorce rates and median age of marriage. # # There is also a weak relationship between the number of Waffle Houses and divorce rates, which is not obvious from the plot above, but will be clearer if we regress `Divorce` against `WaffleHouses` and plot the results. # In[5]: sns.regplot(x="WaffleHouses", y="Divorce", data=dset); # This is an example of a spurious association. We do not expect the number of Waffle Houses in a state to affect the divorce rate, but it is likely correlated with other factors that have an effect on the divorce rate. We will not delve into this spurious association in this tutorial, but the interested reader is encouraged to read Chapters 5 and 6 of [[1](#References)], which explore the problem of causal association in the presence of multiple predictors. # # For simplicity, we will primarily focus on marriage rate and the median age of marriage as our predictors for divorce rate throughout the rest of the tutorial. # ## Regression Model to Predict Divorce Rate # # Let us now write a regression model in *NumPyro* to predict the divorce rate as a linear function of marriage rate and median age of marriage in each of the states. # # First, note that our predictor variables have somewhat different scales. It is a good practice to standardize our predictors and response variables to mean `0` and standard deviation `1`, which should result in [faster inference](https://mc-stan.org/docs/2_19/stan-users-guide/standardizing-predictors-and-outputs.html). # In[6]: def standardize(x): return (x - x.mean()) / x.std() dset["AgeScaled"] = dset.MedianAgeMarriage.pipe(standardize) dset["MarriageScaled"] = dset.Marriage.pipe(standardize) dset["DivorceScaled"] = dset.Divorce.pipe(standardize) # We write the NumPyro model as follows. While the code should largely be self-explanatory, take note of the following: # # - In NumPyro, *model* code is any Python callable which can optionally accept additional arguments and keywords. For HMC, which we will be using for this tutorial, these arguments and keywords remain static during inference, but we can reuse the same model to generate [predictions](#Posterior-Predictive-Distribution) on new data. # - In addition to regular Python statements, the model code also contains primitives like `sample`. These primitives can be interpreted with various side-effects using effect handlers. For more on effect handlers, refer to [[3](#References)], [[4](#References)]. For now, just remember that a `sample` statement makes this a stochastic function that samples some latent parameters from a *prior distribution*. Our goal is to infer the *posterior distribution* of these parameters conditioned on observed data. # - The reason why we have kept our predictors as optional keyword arguments is to be able to reuse the same model as we vary the set of predictors. Likewise, the reason why the response variable is optional is that we would like to reuse this model to sample from the posterior predictive distribution. See the [section](#Posterior-Predictive-Distribution) on plotting the posterior predictive distribution, as an example.
# In[7]: def model(marriage=None, age=None, divorce=None): a = numpyro.sample("a", dist.Normal(0.0, 0.2)) M, A = 0.0, 0.0 if marriage is not None: bM = numpyro.sample("bM", dist.Normal(0.0, 0.5)) M = bM * marriage if age is not None: bA = numpyro.sample("bA", dist.Normal(0.0, 0.5)) A = bA * age sigma = numpyro.sample("sigma", dist.Exponential(1.0)) mu = a + M + A numpyro.sample("obs", dist.Normal(mu, sigma), obs=divorce) # ### Model 1: Predictor - Marriage Rate # # We first try to model the divorce rate as depending on a single variable, marriage rate. As mentioned above, we can use the same `model` code as earlier, but only pass values for the `marriage` and `divorce` keyword arguments. We will use the No U-Turn Sampler (see [[5](#References)] for more details on the NUTS algorithm) to run inference on this simple model. # # The Hamiltonian Monte Carlo (or NUTS) implementation in NumPyro takes in a potential energy function. This is the negative log joint density for the model. Therefore, for our model description above, we need to construct a function which, given the parameter values, returns the potential energy (or negative log joint density). Additionally, the Verlet integrator in HMC (or NUTS) returns sample values simulated using Hamiltonian dynamics in the unconstrained space. As such, continuous variables with bounded support need to be transformed into unconstrained space using bijective transforms. We also need to transform these samples back to their constrained support before returning these values to the user. Thankfully, this is handled on the backend for us, within a convenience class for doing [MCMC inference](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.MCMC) that has the following methods: # # - `run(...)`: runs warmup, adapts step size and mass matrix, and does sampling starting from the sample obtained at the end of the warmup phase. # - `print_summary()`: prints diagnostic information like quantiles, effective sample size, and the Gelman-Rubin diagnostic. # - `get_samples()`: gets samples from the posterior distribution. # # Note the following: # # - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/google/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed. # - We run inference with the `NUTS` sampler. To run vanilla HMC, we can instead use the [HMC](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.HMC) class. # In[8]: # Start from this source of randomness. We will split keys for subsequent operations. rng_key = random.PRNGKey(0) rng_key, rng_key_ = random.split(rng_key) # Run NUTS. kernel = NUTS(model) num_samples = 2000 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples) mcmc.run( rng_key_, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values ) mcmc.print_summary() samples_1 = mcmc.get_samples() # #### Posterior Distribution over the Regression Parameters # # We notice that the progress bar gives us online statistics on the acceptance probability, step size and number of steps taken per sample while running NUTS. In particular, during warmup, we adapt the step size and mass matrix to achieve a certain target acceptance probability, which is 0.8 by default. We were able to successfully adapt our step size to achieve this target in the warmup phase; a short aside below shows how the target itself can be changed.
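# As an aside, if a different target acceptance probability is desired, it can be passed to the `NUTS` kernel via its `target_accept_prob` argument (we will do exactly this later for the measurement error model). The snippet below is only an illustrative sketch, with variable names of our own choosing; it is not used elsewhere in this tutorial.

# Illustrative only: a NUTS kernel that adapts towards a higher target
# acceptance probability of 0.9 instead of the default 0.8.
strict_kernel = NUTS(model, target_accept_prob=0.9)
strict_mcmc = MCMC(strict_kernel, num_warmup=1000, num_samples=2000)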
# # During warmup, the aim is to adapt hyper-parameters such as step size and mass matrix (the HMC algorithm is very sensitive to these hyper-parameters), and to reach the typical set (see [[6](#References)] for more details). If there are any issues in the model specification, the first signal to notice would be low acceptance probabilities or a very high number of steps. We use the sample from the end of the warmup phase to seed the MCMC chain (denoted by the second `sample` progress bar) from which we generate the desired number of samples from our target distribution. # # At the end of inference, NumPyro prints the mean, std and 90% CI values for each of the latent parameters. Note that since we standardized our predictors and response variable, we would expect the intercept to have mean 0, as can be seen here. It also prints other convergence diagnostics on the latent parameters in the model, including [effective sample size](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.effective_sample_size) and the [Gelman-Rubin diagnostic](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.gelman_rubin) ($\hat{R}$). The values for these diagnostics indicate that the chain has converged to the target distribution. In our case, the "target distribution" is the posterior distribution over the latent parameters that we are interested in. Note that this is often worth verifying with multiple chains for more complicated models. In the end, `samples_1` is a collection (in our case, a `dict` keyed by the names of the latent sites) containing samples from the posterior distribution for each of the latent parameters in the model. # # To look at our regression fit, let us plot the regression line using our posterior estimates for the regression parameters, along with the 90% Credible Interval (CI). Note that the [hpdi](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.hpdi) function in NumPyro's diagnostics module can be used to compute the CI. In the functions below, note that the collected samples from the posterior are all along the leading axis. # In[9]: def plot_regression(x, y_mean, y_hpdi): # Sort values for plotting by x axis idx = jnp.argsort(x) marriage = x[idx] mean = y_mean[idx] hpdi = y_hpdi[:, idx] divorce = dset.DivorceScaled.values[idx] # Plot fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6)) ax.plot(marriage, mean) ax.plot(marriage, divorce, "o") ax.fill_between(marriage, hpdi[0], hpdi[1], alpha=0.3, interpolate=True) return ax # Compute empirical posterior distribution over mu posterior_mu = ( jnp.expand_dims(samples_1["a"], -1) + jnp.expand_dims(samples_1["bM"], -1) * dset.MarriageScaled.values ) mean_mu = jnp.mean(posterior_mu, axis=0) hpdi_mu = hpdi(posterior_mu, 0.9) ax = plot_regression(dset.MarriageScaled.values, mean_mu, hpdi_mu) ax.set( xlabel="Marriage rate", ylabel="Divorce rate", title="Regression line with 90% CI" ); # We can see from the plot that the CI broadens towards the tails where the data is relatively sparse, as can be expected. # # #### Prior Predictive Distribution # # Let us check that we have set sensible priors by sampling from the prior predictive distribution. NumPyro provides a handy [Predictive](http://num.pyro.ai/en/latest/utilities.html#numpyro.infer.util.Predictive) utility for this purpose.
# In[10]: from numpyro.infer import Predictive rng_key, rng_key_ = random.split(rng_key) prior_predictive = Predictive(model, num_samples=100) prior_predictions = prior_predictive(rng_key_, marriage=dset.MarriageScaled.values)[ "obs" ] mean_prior_pred = jnp.mean(prior_predictions, axis=0) hpdi_prior_pred = hpdi(prior_predictions, 0.9) ax = plot_regression(dset.MarriageScaled.values, mean_prior_pred, hpdi_prior_pred) ax.set(xlabel="Marriage rate", ylabel="Divorce rate", title="Predictions with 90% CI"); # #### Posterior Predictive Distribution # # Let us now look at the posterior predictive distribution to see how our predictive distribution looks with respect to the observed divorce rates. To get samples from the posterior predictive distribution, we need to run the model by substituting the latent parameters with samples from the posterior. Note that by default we generate a single prediction for each sample from the joint posterior distribution, but this can be controlled using the `num_samples` argument. # In[11]: rng_key, rng_key_ = random.split(rng_key) predictive = Predictive(model, samples_1) predictions = predictive(rng_key_, marriage=dset.MarriageScaled.values)["obs"] df = dset.filter(["Location"]) df["Mean Predictions"] = jnp.mean(predictions, axis=0) df.head() # #### Predictive Utility With Effect Handlers # # To remove the magic behind `Predictive`, let us see how we can combine [effect handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) with the [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) JAX primitive to implement our own simplified predictive utility function that can do vectorized predictions. # In[12]: def predict(rng_key, post_samples, model, *args, **kwargs): model = handlers.seed(handlers.condition(model, post_samples), rng_key) model_trace = handlers.trace(model).get_trace(*args, **kwargs) return model_trace["obs"]["value"] # vectorize predictions via vmap predict_fn = vmap( lambda rng_key, samples: predict( rng_key, samples, model, marriage=dset.MarriageScaled.values ) ) # Note the use of the `condition`, `seed` and `trace` effect handlers in the `predict` function. # # - The `seed` effect-handler is used to wrap a stochastic function with an initial `PRNGKey` seed. When a sample statement inside the model is called, it uses the existing seed to sample from a distribution, but this effect-handler also splits the existing key to ensure that future `sample` calls in the model use the newly split key instead. This is to prevent us from having to explicitly pass in a `PRNGKey` to each `sample` statement in the model. # - The `condition` effect handler conditions the latent sample sites to certain values. In our case, we are conditioning on values from the posterior distribution returned by MCMC. # - The `trace` effect handler runs the model and records the execution trace within an `OrderedDict`. This trace object contains execution metadata that is useful for computing quantities such as the log joint density. # # It should be clear now that the `predict` function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the MCMC run above) to generate predictions. Note the use of JAX's auto-vectorization transform called [vmap](https://github.com/google/jax#auto-vectorization-with-vmap) to vectorize predictions. Note that if we didn't use `vmap`, we would have to use a native Python for loop over the posterior samples, which would be much slower, as sketched below.
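# Purely for illustration, here is a minimal sketch of that slower, non-vectorized alternative, using the `predict` helper defined above. To keep the Python loop cheap we only push the first 100 posterior draws through the model; the variable names below are our own and this cell is not used later in the tutorial.

# Illustrative only: loop over (a subset of) posterior draws in plain Python
# instead of vectorizing with vmap. Each iteration re-runs the model once.
n_draws = 100
loop_keys = random.split(rng_key_, n_draws)
looped_predictions = jnp.stack(
    [
        predict(
            loop_keys[i],
            {name: values[i] for name, values in samples_1.items()},  # i-th posterior draw
            model,
            marriage=dset.MarriageScaled.values,
        )
        for i in range(n_draws)
    ]
)  # shape: (n_draws, 50)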
Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using `vmap`, we will get a `predictions_1` array of shape `(num_samples, 50)`. We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the `Predictive` utility class. # In[13]: # Using the same key as we used for Predictive - note that the results are identical. predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1) mean_pred = jnp.mean(predictions_1, axis=0) df = dset.filter(["Location"]) df["Mean Predictions"] = mean_pred df.head() # In[14]: hpdi_pred = hpdi(predictions_1, 0.9) ax = plot_regression(dset.MarriageScaled.values, mean_pred, hpdi_pred) ax.set(xlabel="Marriage rate", ylabel="Divorce rate", title="Predictions with 90% CI"); # We have used the same `plot_regression` function as earlier. We notice that our CI for the predictive distribution is much broader as compared to the last plot due to the additional noise introduced by the `sigma` parameter. Most data points lie well within the 90% CI, which indicates a good fit. # # #### Posterior Predictive Density # # Likewise, making use of effect-handlers and `vmap`, we can also compute the log likelihood for this model given the dataset, and the log posterior predictive density [[8](#References)], which is given by # $$ \log \prod_{i=1}^{n} \int p(y_i \mid \theta) \, p_{post}(\theta) \, d\theta \approx \sum_{i=1}^{n} \log \frac{\sum_{s} p(y_i \mid \theta^{s})}{S} = \sum_{i=1}^{n} \left( \log \sum_{s} p(y_i \mid \theta^{s}) - \log S \right) $$ # # Here, $i$ indexes the observed data points $y$, $s$ indexes the posterior samples over the latent parameters $\theta$, and $S$ is the total number of posterior samples. If the posterior predictive density for a model has a comparatively high value, it indicates that the observed data-points have higher probability under the given model. # In[15]: def log_likelihood(rng_key, params, model, *args, **kwargs): model = handlers.condition(model, params) model_trace = handlers.trace(model).get_trace(*args, **kwargs) obs_node = model_trace["obs"] return obs_node["fn"].log_prob(obs_node["value"]) def log_pred_density(rng_key, params, model, *args, **kwargs): n = list(params.values())[0].shape[0] log_lk_fn = vmap( lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs) ) log_lk_vals = log_lk_fn(random.split(rng_key, n), params) return (logsumexp(log_lk_vals, 0) - jnp.log(n)).sum() # Note that NumPyro provides the [log_likelihood](http://num.pyro.ai/en/latest/utilities.html#log-likelihood) utility function that can be used directly for computing the log likelihood, as in the first function above, for any general model. In this tutorial, we would like to emphasize that there is nothing magical about such utility functions, and you can roll your own inference utilities using NumPyro's effect handling stack. # In[16]: rng_key, rng_key_ = random.split(rng_key) print( "Log posterior predictive density: {}".format( log_pred_density( rng_key_, samples_1, model, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values, ) ) ) # ### Model 2: Predictor - Median Age of Marriage # # We will now model the divorce rate as a function of the median age of marriage. The computations are mostly a reproduction of what we did for Model 1. Notice the following: # # - Divorce rate is inversely related to the age of marriage.
Hence states where the median age of marriage is low will likely have a higher divorce rate. # - We get a higher log posterior predictive density as compared to Model 1, indicating that median age of marriage is likely a much better predictor of divorce rate. # In[17]: rng_key, rng_key_ = random.split(rng_key) mcmc.run(rng_key_, age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values) mcmc.print_summary() samples_2 = mcmc.get_samples() # In[18]: posterior_mu = ( jnp.expand_dims(samples_2["a"], -1) + jnp.expand_dims(samples_2["bA"], -1) * dset.AgeScaled.values ) mean_mu = jnp.mean(posterior_mu, axis=0) hpdi_mu = hpdi(posterior_mu, 0.9) ax = plot_regression(dset.AgeScaled.values, mean_mu, hpdi_mu) ax.set( xlabel="Median marriage age", ylabel="Divorce rate", title="Regression line with 90% CI", ); # In[19]: rng_key, rng_key_ = random.split(rng_key) predictions_2 = Predictive(model, samples_2)(rng_key_, age=dset.AgeScaled.values)["obs"] mean_pred = jnp.mean(predictions_2, axis=0) hpdi_pred = hpdi(predictions_2, 0.9) ax = plot_regression(dset.AgeScaled.values, mean_pred, hpdi_pred) ax.set(xlabel="Median Age", ylabel="Divorce rate", title="Predictions with 90% CI"); # In[20]: rng_key, rng_key_ = random.split(rng_key) print( "Log posterior predictive density: {}".format( log_pred_density( rng_key_, samples_2, model, age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values, ) ) ) # ### Model 3: Predictor - Marriage Rate and Median Age of Marriage # # Finally, we will also model the divorce rate as depending on both the marriage rate and the median age of marriage. Note that this model's posterior predictive density is similar to that of Model 2, which likely indicates that the marginal information from marriage rate in predicting divorce rate is low when the median age of marriage is already known. # In[21]: rng_key, rng_key_ = random.split(rng_key) mcmc.run( rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values, ) mcmc.print_summary() samples_3 = mcmc.get_samples() # In[22]: rng_key, rng_key_ = random.split(rng_key) print( "Log posterior predictive density: {}".format( log_pred_density( rng_key_, samples_3, model, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values, ) ) ) # ### Divorce Rate Residuals by State # # The regression plots above show that the observed divorce rates for many states differ considerably from the mean regression line. To dig deeper into how the last model (Model 3) under-predicts or over-predicts for each of the states, we will plot the posterior predictive and residuals (`Observed divorce rate - Predicted divorce rate`) for each of the states. # In[23]: # Predictions for Model 3. rng_key, rng_key_ = random.split(rng_key) predictions_3 = Predictive(model, samples_3)( rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values )["obs"] y = jnp.arange(50) fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 16)) pred_mean = jnp.mean(predictions_3, axis=0) pred_hpdi = hpdi(predictions_3, 0.9) residuals_3 = dset.DivorceScaled.values - predictions_3 residuals_mean = jnp.mean(residuals_3, axis=0) residuals_hpdi = hpdi(residuals_3, 0.9) idx = jnp.argsort(residuals_mean) # Plot posterior predictive ax[0].plot(jnp.zeros(50), y, "--") ax[0].errorbar( pred_mean[idx], y, xerr=pred_hpdi[1, idx] - pred_mean[idx], marker="o", ms=5, mew=4, ls="none", alpha=0.8, ) ax[0].plot(dset.DivorceScaled.values[idx], y, marker="o", ls="none", color="gray") ax[0].set( xlabel="Posterior Predictive (red) vs. Actuals (gray)", ylabel="State", title="Posterior Predictive with 90% CI", )
ax[0].set_yticks(y) ax[0].set_yticklabels(dset.Loc.values[idx], fontsize=10) # Plot residuals err = residuals_hpdi[1] - residuals_mean ax[1].plot(jnp.zeros(50), y, "--") ax[1].errorbar( residuals_mean[idx], y, xerr=err[idx], marker="o", ms=5, mew=4, ls="none", alpha=0.8 ) ax[1].set(xlabel="Residuals", ylabel="State", title="Residuals with 90% CI") ax[1].set_yticks(y) ax[1].set_yticklabels(dset.Loc.values[idx], fontsize=10); # The plot on the left shows the mean predictions with 90% CI for each of the states using Model 3. The gray markers indicate the actual observed divorce rates. The right plot shows the residuals for each of the states, and both these plots are sorted by the residuals, i.e. at the bottom, we are looking at states where the model predictions are higher than the observed rates, whereas at the top, the reverse is true. # # Overall, the model fit seems good because most observed data points lie within a 90% CI around the mean predictions. However, notice how the model over-predicts by a large margin for states like Idaho (bottom left), and on the other end under-predicts for states like Maine (top right). This likely indicates that there are other factors affecting the divorce rate across different states that are missing from our model. Even ignoring other socio-political variables, one such factor that we have not yet modeled is the measurement noise given by `Divorce SE` in the dataset. We will explore this in the next section. # ## Regression Model with Measurement Error # # Note that in our previous models, each data point influences the regression line equally. Is this well justified? We will build on the previous model to incorporate the measurement error given by the `Divorce SE` variable in the dataset. Incorporating measurement noise will be useful in ensuring that observations that have higher confidence (i.e. lower measurement noise) have a greater impact on the regression line. On the other hand, this will also help us better model outliers with high measurement errors. For more details on modeling errors due to measurement noise, refer to Chapter 14 of [[1](#References)]. # # To do this, we will reuse Model 3, with the only change that the final observed value has a measurement error given by `divorce_sd` (notice that this has to be standardized since the `divorce` variable itself has been standardized to mean 0 and std 1).
# In[24]: def model_se(marriage, age, divorce_sd, divorce=None): a = numpyro.sample("a", dist.Normal(0.0, 0.2)) bM = numpyro.sample("bM", dist.Normal(0.0, 0.5)) M = bM * marriage bA = numpyro.sample("bA", dist.Normal(0.0, 0.5)) A = bA * age sigma = numpyro.sample("sigma", dist.Exponential(1.0)) mu = a + M + A divorce_rate = numpyro.sample("divorce_rate", dist.Normal(mu, sigma)) numpyro.sample("obs", dist.Normal(divorce_rate, divorce_sd), obs=divorce) # In[25]: # Standardize dset["DivorceScaledSD"] = dset["Divorce SE"] / jnp.std(dset.Divorce.values) # In[26]: rng_key, rng_key_ = random.split(rng_key) kernel = NUTS(model_se, target_accept_prob=0.9) mcmc = MCMC(kernel, num_warmup=1000, num_samples=3000) mcmc.run( rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values, divorce_sd=dset.DivorceScaledSD.values, divorce=dset.DivorceScaled.values, ) mcmc.print_summary() samples_4 = mcmc.get_samples() # ### Effect of Incorporating Measurement Noise on Residuals # # Notice that our values for the regression coefficients are very similar to those of Model 3. However, introducing measurement noise allows us to more closely match our predictive distribution to the observed values. We can see this if we plot the residuals as earlier. # In[27]: rng_key, rng_key_ = random.split(rng_key) predictions_4 = Predictive(model_se, samples_4)( rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values, divorce_sd=dset.DivorceScaledSD.values, )["obs"] # In[28]: sd = dset.DivorceScaledSD.values residuals_4 = dset.DivorceScaled.values - predictions_4 residuals_mean = jnp.mean(residuals_4, axis=0) residuals_hpdi = hpdi(residuals_4, 0.9) err = residuals_hpdi[1] - residuals_mean idx = jnp.argsort(residuals_mean) y = jnp.arange(50) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 16)) # Plot Residuals ax.plot(jnp.zeros(50), y, "--") ax.errorbar( residuals_mean[idx], y, xerr=err[idx], marker="o", ms=5, mew=4, ls="none", alpha=0.8 ) # Plot SD ax.errorbar(residuals_mean[idx], y, xerr=sd[idx], ls="none", color="orange", alpha=0.9) # Plot earlier mean residual ax.plot( jnp.mean(dset.DivorceScaled.values - predictions_3, 0)[idx], y, ls="none", marker="o", ms=6, color="black", alpha=0.6, ) ax.set(xlabel="Residuals", ylabel="State", title="Residuals with 90% CI") ax.set_yticks(y) ax.set_yticklabels(dset.Loc.values[idx], fontsize=10) ax.text( -2.8, -7, "Residuals (with error-bars) from current model (in red). " "Black marker \nshows residuals from the previous model (Model 3). " "Measurement \nerror is indicated by orange bar.", ); # The plot above shows the residuals for each of the states, along with the measurement noise given by the inner (orange) error bars. The black dots are the mean residuals from our earlier Model 3. Notice how having an additional degree of freedom to model the measurement noise has shrunk the residuals. In particular, for Idaho and Maine, our predictions are now much closer to the observed values after incorporating measurement noise in the model. # # To better see how measurement noise affects the movement of the regression line, let us plot the residuals with respect to the measurement noise.
# In[29]: fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 6)) x = dset.DivorceScaledSD.values y1 = jnp.mean(residuals_3, 0) y2 = jnp.mean(residuals_4, 0) ax.plot(x, y1, ls="none", marker="o") ax.plot(x, y2, ls="none", marker="o") for i, (j, k) in enumerate(zip(y1, y2)): ax.plot([x[i], x[i]], [j, k], "--", color="gray") ax.set( xlabel="Measurement Noise", ylabel="Residual", title="Mean residuals (Model 4: red, Model 3: blue)", ); # The plot above shows what has happened in more detail: the regression line itself has moved to ensure a better fit for observations with low measurement noise (left of the plot), where the residuals have shrunk very close to 0. That is to say that data points with low measurement error have a concomitantly higher contribution in determining the regression line. On the other hand, for states with high measurement error (right of the plot), incorporating measurement noise allows us to move our posterior distribution mass closer to the observations, resulting in a shrinkage of the residuals as well. # ## References # # 1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan. CRC Press. # 2. Stan Development Team. [Stan User's Guide](https://mc-stan.org/docs/2_19/stan-users-guide/index.html) # 3. Goodman, N.D., and Stuhlmueller, A. (2014). [The Design and Implementation of Probabilistic Programming Languages](http://dippl.org/) # 4. Pyro Development Team. [Poutine: A Guide to Programming with Effect Handlers in Pyro](http://pyro.ai/examples/effect_handlers.html) # 5. Hoffman, M.D., and Gelman, A. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. # 6. Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. # 7. JAX Development Team (2018). [Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more](https://github.com/google/jax) # 8. Gelman, A., Hwang, J., and Vehtari, A. [Understanding predictive information criteria for Bayesian models](https://arxiv.org/pdf/1307.5928.pdf)