title: Modeling Shark Attacks in Python with PyMC3

tags: Sharks, PyMC3, Bayesian Statistics

On a recent visit to Cape May, New Jersey I bought and read Shark Attacks of the Jersey Shore: A History, which is an interesting account of verified shark attacks in New Jersey since the nineteenth century.

While I was reading the book, I grew curious about modeling the frequency of shark attacks and went on the hunt for relevant data. There have not been many recent shark attacks in New Jersey, so I broadened my search and found the Global Shark Attack File (GSAF). The GSAF provides a fairly detailed incident log which appears to be updated every few days. This post presents an analysis of the GSAF data in Python using PyMC3. (It is worth mentioning that I am not a shark biologist, and I am sure specialists have produced much more useful and accurate models of shark attacks. Still, this seems like a fun small project as long as we don't take the results too seriously.)

First we make some Python imports and do a bit of housekeeping.

In [1]:
%matplotlib inline
In [2]:
from functools import reduce
from warnings import filterwarnings
In [3]:
from aesara import tensor as at
import arviz as az
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import us
You are running the v4 development version of PyMC3 which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc3/tree/v3
In [4]:
filterwarnings('ignore', category=UserWarning, module='aesara')
filterwarnings('ignore', category=UserWarning, module='arviz')
filterwarnings('ignore', category=UserWarning, module='pandas')
In [5]:
plt.rcParams['figure.figsize'] = (8, 6)
sns.set(color_codes=True)

Load and Transform the Data

We begin by downloading the data from the GSAF and loading the relevant columns into a Pandas data frame.

In [6]:
%%bash
DATA_URL=http://www.sharkattackfile.net/spreadsheets/GSAF5.xls
DATA_DEST=./data/GSAF5.xls

if [[ ! -e $DATA_DEST ]];
then
    wget -q -O $DATA_DEST $DATA_URL
fi
In [7]:
full_df = pd.read_excel('./data/GSAF5.xls',
                        usecols=[
                           "Case Number", "Date", "Type",
                           "Country", "Area", "Location",
                           "Fatal (Y/N)"
                        ],
                        index_col="Case Number")
full_df["Date"] = full_df["Date"].apply(pd.to_datetime, errors='coerce')
In [8]:
full_df.head()
Out[8]:
Date Type Country Area Location Fatal (Y/N)
Case Number
2021.06.14.b 2021-06-14 Unprovoked USA Florida New Smyrna Beach, Volusia County N
2021.06.14.a 2021-06-14 Unprovoked USA Florida New Smyrna Beach, Volusia County N
2021.06.12 2021-06-12 Provoked ENGLAND West Sussex Littlehampton N
2021.06.11 2021-06-11 Unprovoked AUSTRALIA Western Australia Five Fingers Reef N
2021.05.23 2021-05-23 Unprovoked USA South Carolina Burkes Beach, Hilton Head, Beaufort County N
In [9]:
full_df.tail()
Out[9]:
Date Type Country Area Location Fatal (Y/N)
Case Number
NaN NaT NaN NaN NaN NaN NaN
NaN NaT NaN NaN NaN NaN NaN
NaN NaT NaN NaN NaN NaN NaN
NaN NaT NaN NaN NaN NaN NaN
xx NaT NaN NaN NaN NaN NaN

We can already see by inspecting the data frame that there will be quite a bit of missing data to handle.

In [10]:
full_df.index.isnull().mean()
Out[10]:
0.6595373706846449
In [11]:
full_df.isnull().mean()
Out[11]:
Date           0.775621
Type           0.742880
Country        0.744391
Area           0.760587
Location       0.763726
Fatal (Y/N)    0.763842
dtype: float64

We begin by filtering out rows with missing data in important columns.

In [12]:
FILTERS = [
    ~full_df.index.isnull(),
    ~full_df["Date"].isnull(),
    ~full_df["Type"].isnull()
]
In [13]:
(full_df[reduce(np.logical_and, FILTERS)]
        .isnull()
        .mean())
Out[13]:
Date           0.000000
Type           0.000000
Country        0.004843
Area           0.052413
Location       0.061927
Fatal (Y/N)    0.079398
dtype: float64

We see now that there is significantly less missing data, so we proceed to examine various aspects of the data.

In [14]:
ax = (full_df[reduce(np.logical_and, FILTERS)]
             ["Type"]
             .value_counts()
             .plot(kind='barh'))

ax.set_xscale('log');
ax.set_xlabel("Number of attacks");

ax.invert_yaxis();
ax.set_ylabel("Type of attack");

Unprovoked attacks are by far the most common. We will focus our analysis on this type of attack.

In [15]:
FILTERS.append(full_df["Type"] == "Unprovoked")
In [16]:
ax = (full_df[reduce(np.logical_and, FILTERS)]
             ["Country"]
             .value_counts()
             .plot(kind='barh', figsize=(8, 30)))

ax.set_xscale('log');
ax.set_xlabel("Number of unprovoked attacks");

ax.invert_yaxis();
ax.set_ylabel("Country");

While the data contain information about attacks in many countries, we will focus on the United States, partly because I was in New Jersey when I started thinking about this problem and partly because standardized data are easier to find for regions within a single country than across different countries.

In [17]:
FILTERS.append(full_df["Country"] == "USA")
In [18]:
ax = (full_df[reduce(np.logical_and, FILTERS)]
             ["Area"]
             .value_counts()
             .plot(kind='barh', figsize=(8, 12)))

ax.set_xscale('log');
ax.set_xlabel("Number of unprovoked attacks");

ax.invert_yaxis();
ax.set_ylabel("Country");

We see that for unprovoked attacks in the United States, Area roughly corresponds to state, with some territories included as well. Since most of the territories appear only a handful of times, we restrict our analysis to states, for which it is easier to find supplementary information.

In [19]:
FILTERS.append(
    full_df["Area"].isin([
        state.name for state in us.states.STATES
    ])
)

Finally we look at how the number of shark attacks has changed over time.

In [20]:
ax = (full_df.assign(Year=full_df["Date"].dt.year)
             [reduce(np.logical_and, FILTERS)]
             ["Year"]
             .value_counts()
             .sort_index()
             .plot())

ax.set_xlabel("Year");
ax.set_ylabel("Unprovoked shark attacks\nin the United States");

We see that the number of shark attacks has increased over time. This phenomenon is likely due partly to population growth and partly to improved reporting of shark attacks. We will keep a relatively modern focus and analyze shark attacks between 1960 and 2020.

In [21]:
YEAR_RANGE = (1960, 2020)

FILTERS.append(full_df["Date"].dt.year.between(*YEAR_RANGE))

Now that we have defined the set of attacks we will model, we produce another data frame including only these attacks and make some light transformations of the data.

In [22]:
df = (full_df[reduce(np.logical_and, FILTERS)]
             .copy()
             .rename(columns={"Area": "State"}))
df["Year"] = df["Date"].dt.year
In [23]:
df.head()
Out[23]:
Date Type Country State Location Fatal (Y/N) Year
Case Number
2020.12.30 2020-12-30 Unprovoked USA California Coronado, San Diego County N 2020
2020.12.08 2020-12-08 Unprovoked USA Hawaii Honolua Bay Y 2020
2020.12.06.b 2020-12-06 Unprovoked USA Oregon Seaside Cove, Clatsop County N 2020
2020.11.26 2020-11-26 Unprovoked USA Hawaii Maui N 2020
2020.10.31 2020-10-31 Unprovoked USA Florida Ormond Beach N 2020
In [24]:
df.tail()
Out[24]:
Date Type Country State Location Fatal (Y/N) Year
Case Number
1904.00.00.a 1970-01-01 00:00:00.000001904 Unprovoked USA Hawaii Off Diamond Head, Honolulu, O'ahu Y 1970
1896.00.00.b 1970-01-01 00:00:00.000001896 Unprovoked USA Florida NaN Y 1970
1883.00.00.a 1970-01-01 00:00:00.000001883 Unprovoked USA South Carolina NaN Y 1970
1882.00.00.b 1970-01-01 00:00:00.000001882 Unprovoked USA Florida In the bay near the naval yard at Pensacola, E... N 1970
1852.00.00 1970-01-01 00:00:00.000001852 Unprovoked USA South Carolina Mount Pleasant, Charleston County Y 1970
In [25]:
df.shape
Out[25]:
(1514, 7)

After applying these filters, just over 1,500 attacks remain. This analysis will focus on the number of shark attacks in a state in a given year. Subsequent posts may analyze other aspects of this data. First we count the number of attacks in a given state in a given year. (We will shorten the phrase "unprovoked shark attacks in the United States between 1960 and 2020" to "attacks" for the remainder of the post.)

In [26]:
attacks_nz = (df.groupby(["State", "Year"])
                .size()
                .rename("Attacks"))
In [27]:
attacks_nz.describe()
Out[27]:
count    333.000000
mean       4.546547
std        6.527174
min        1.000000
25%        1.000000
50%        2.000000
75%        5.000000
max       38.000000
Name: Attacks, dtype: float64

The series attacks_nz includes a row only when there was at least one attack in that state in that year. We also want to include zero entries for state/year combinations that saw no attacks, which we do now by reindexing attacks_nz.

In [28]:
attacks_index = (pd.MultiIndex.from_product((
                        attacks_nz.index
                                  .get_level_values("State")
                                  .unique(),
                        YEAR_RANGE[0] + np.arange(attacks_nz.index
                                                            .get_level_values("Year")
                                                            .values
                                                            .ptp())))
                   .rename("Year", level=1))
attacks_df = (attacks_nz.reindex(attacks_index, fill_value=0)
                        .astype(np.int64)
                        .to_frame())
In [29]:
attacks_df.head()
Out[29]:
Attacks
State Year
Alabama 1960 0
1961 0
1962 0
1963 0
1964 0
In [30]:
attacks_df.tail()
Out[30]:
Attacks
State Year
Washington 2015 0
2016 0
2017 1
2018 0
2019 0

Modeling

We now turn to modeling the data.

In [31]:
ax = attacks_df["Attacks"].hist(bins=attacks_df["Attacks"].max() + 1)

ax.set_xlabel("Number of attacks");

ax.set_yscale('log');
ax.set_ylabel("Number of state-years");

We see that the vast majority of state-years have no shark attacks and that, when there is at least one attack, there are rarely very many. The index of dispersion computed below is significantly larger than one, so the data shows overdispersion.
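For reference, the index of dispersion of a sample is the ratio of its variance to its mean,

$$D = \frac{s^2}{\bar{y}},$$

which is approximately one for Poisson-distributed data, so values well above one indicate more variability than a Poisson model can accommodate.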

In [32]:
attacks_df.var() / attacks_df.mean()
Out[32]:
Attacks    12.729978
dtype: float64

Negative Binomial

Due to this overdispersion, we begin with a Negative Binomial model.

Let $y_{i, t}$ represent the number of attacks in the $i$-th state in year $t$. We use the priors

$$ \begin{align*} \mu & \sim \operatorname{Lognormal}(0, 2.5^2) \\ \alpha & \sim \operatorname{Half}-N(2.5^2). \end{align*} $$
In [33]:
with pm.Model() as nb_model:
    μ = pm.Lognormal("μ", 0., 2.5)
    α = pm.HalfNormal("α", 2.5)

We now let $y_{i, t} \sim NB(\mu, \alpha)$.
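Here $NB(\mu, \alpha)$ denotes PyMC3's mean-dispersion parameterization of the negative binomial distribution, for which

$$\mathbb{E}[y_{i, t}] = \mu \quad \text{and} \quad \operatorname{Var}(y_{i, t}) = \mu + \frac{\mu^2}{\alpha},$$

so smaller values of $\alpha$ correspond to more overdispersion.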

In [34]:
y = attacks_df["Attacks"].values
In [35]:
with nb_model:
    obs = pm.NegativeBinomial("obs", μ, α, observed=y)

We use PyMC3 to sample from the posterior distribution of this model.

In [36]:
CHAINS = 3
SEED = 12345

SAMPLE_KWARGS = {
    'cores': CHAINS,
    'random_seed': [SEED + i for i in range(CHAINS)],
    'return_inferencedata': True
}
In [37]:
with nb_model:
    nb_trace = pm.sample(**SAMPLE_KWARGS)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [μ, α]
100.00% [6000/6000 00:15<00:00 Sampling 3 chains, 0 divergences]
Sampling 3 chains for 1_000 tune and 1_000 draw iterations (3_000 + 3_000 draws total) took 20 seconds.

The standard sampling diagnostics (energy plots, BFMI, and $\hat{R}$) show no cause for concern.
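Recall that $\hat{R}$ compares the variance between chains to the variance within chains. In its simplest (non-rank-normalized) form,

$$\hat{R} = \sqrt{\frac{\frac{N - 1}{N} W + \frac{1}{N} B}{W}},$$

where $W$ is the average within-chain variance, $B$ is $N$ times the variance of the per-chain means, and $N$ is the number of draws per chain; values very close to one indicate that the chains are in agreement. (ArviZ actually reports a rank-normalized, split-chain refinement of this statistic.)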

In [38]:
def make_diagnostic_plots(trace, axes=None, min_mult=0.995, max_mult=1.005):
    if axes is None:
        fig, axes = plt.subplots(ncols=2,
                                 sharex=False, sharey=False,
                                 figsize=(16, 6))
    else:
        fig = axes[0].figure

    # Energy plot (and BFMI) in the left panel
    az.plot_energy(trace, ax=axes[0])

    # Maximum Rhat per variable in the right panel
    rhat = az.rhat(trace).max()
    axes[1].barh(np.arange(len(rhat.variables)), rhat.to_array(),
                 tick_label=list(rhat.variables.keys()))
    axes[1].axvline(1, c='k', ls='--')

    axes[1].set_xlim(
        min_mult * min(rhat.min().to_array().min(), 1),
        max_mult * max(rhat.max().to_array().max(), 1)
    )
    axes[1].set_xlabel(r"$\hat{R}$")

    axes[1].set_ylabel("Variable")
    
    return fig, axes
In [39]:
make_diagnostic_plots(nb_trace);

Examining the posterior distribution of $\mu$ shows that the model predicts slightly more than one shark attack per state per year on average.

In [40]:
az.plot_posterior(nb_trace, var_names=["μ"]);

This prediction of about one shark attack per state per year captures the nationwide average well enough, but it is clearly not very useful at the state level. We can see exactly how poorly it performs by examining the model's posterior predictions.

In [41]:
with nb_model:
    nb_pp_trace = pm.sample_posterior_predictive(nb_trace)
100.00% [3000/3000 00:05<00:00]

First we produce a long data frame where each row represents a single posterior predictive sample from the distribution of the number of attacks in a given state-year. This data frame will be the basis for all of our posterior predictive plots.

In [42]:
def make_pp_full_df(pp_obs):
    return (pd.DataFrame(
                {i: samples for i, samples in enumerate(pp_obs)},
                index=attacks_df.index
              )
              .reset_index()
              .melt(id_vars=["State", "Year"],
                    var_name="Sample")
              .set_index(["State", "Year", "Sample"]))
In [43]:
nb_pp_full_df = make_pp_full_df(nb_pp_trace["obs"])
In [44]:
nb_pp_full_df.head()
Out[44]:
value
State Year Sample
Alabama 1960 0 1
1961 0 0
1962 0 0
1963 0 0
1964 0 0
In [45]:
ALPHA = 0.05
In [46]:
def summarize_pp_df(pp_full_df, level=None, alpha=ALPHA):
    if level is None:
        level = ["State", "Year"]
    
    return (pp_full_df.groupby(level=level)
                      ["value"]
                      .agg(
                          mean='mean',
                          sum='sum',
                          low=lambda s: s.quantile(alpha / 2.),
                          high=lambda s: s.quantile(1. - alpha / 2.),
                      )
                      .assign(attacks=attacks_df["Attacks"]
                                                .groupby(level=level)
                                                .sum()))
In [47]:
nb_pp_df = summarize_pp_df(nb_pp_full_df)
In [48]:
nb_pp_df.head()
Out[48]:
mean sum low high attacks
State Year
Alabama 1960 1.231000 3693 0 12.000 0
1961 1.212667 3638 0 12.000 0
1962 1.206000 3618 0 11.025 0
1963 1.232000 3696 0 12.000 0
1964 1.177667 3533 0 11.000 0

The data frame nb_pp_df contains the posterior predictive mean and posterior predictive quantiles for each state-year combination. We now plot the predictions and actual data, broken down by state.

In [49]:
def plot_pp_over_time(data=None, y="mean", *args, **kwargs):
    pp_df = data
    
    ax = plt.gca()
    (pp_df.plot("Year", y,
                c='k', label="Posterior expected value",
                ax=ax))
    (pp_df.reset_index()
          .plot("Year", "attacks",
                kind='scatter', c='k', zorder=5,
                label="Actual", ax=ax))

    ax.fill_between(pp_df["Year"], pp_df["low"], pp_df["high"],
                    color='C0', alpha=0.5,
                    label="95% posterior credible interval")

    ax.set_ylabel("Number of shark attacks")

    return ax
In [50]:
grid = sns.FacetGrid(nb_pp_df.reset_index(), col="State",
                     col_wrap=2, sharey=False, aspect=1.5)
grid.map_dataframe(plot_pp_over_time);

grid.axes[0].legend(loc="upper left");
grid.set_titles("{col_name}");
grid.fig.tight_layout();

We see that the predictions for each state are indeed the same (ignoring some slight Monte Carlo variation), and therefore vastly overpredict shark attacks for most states while massively underpredicting shark attacks for some states.

The plot below, which does not include a time axis, further reinforces this point.

In [51]:
nb_pp_state_df = summarize_pp_df(nb_pp_full_df, level="State")
ax = (nb_pp_state_df.reset_index()
                    .plot.scatter("mean", "State", color='C0',
                                  xerr=nb_pp_state_df[["low", "high"]]
                                                     .sub(nb_pp_state_df["mean"],
                                                          axis=0)
                                                     .abs()
                                                     .values.T,
                                  zorder=5,
                                  label="Posterior predictive mean"))
(attacks_df.reset_index()
           .plot.scatter("Attacks", "State",
                         color='k', alpha=0.5,
                         label="Actual", ax=ax));

ax.set_xlabel("Number of shark Attacks");
ax.invert_yaxis();

The predictions from this model are particularly bad for states like California, Florida, Hawaii, and the Carolinas, where years with many shark attacks are relatively more common than they are for other states.

Negative Binomial Regression

California, Florida, and to a lesser extent the Carolinas all have relatively large populations, and in each of these states the population is relatively concentrated along the coast.

The table below shows that overdispersion is still quite prevalent even when we condition on state, so we will use negative binomial regression to account for the effect of state-level factors on the number of attacks.

In [52]:
(attacks_df.groupby(level="State")
           .var()
           .div(attacks_df.groupby("State")
                          .mean()))
Out[52]:
Attacks
State
Alabama 1.090395
California 1.408464
Connecticut 1.000000
Delaware 0.966102
Florida 8.373881
Georgia 1.090395
Hawaii 3.457544
Louisiana 0.915254
Maine 1.000000
Massachusetts 1.254237
Mississippi 1.000000
New Jersey 1.457627
New York 2.000000
North Carolina 2.342930
Oregon 1.156634
Rhode Island 1.000000
South Carolina 2.382694
Texas 1.917465
Virginia 0.881356
Washington 0.983051

First we load state-level population and coastline data from Wikipedia.

In [53]:
POP_URL = "https://en.wikipedia.org/wiki/List_of_U.S._states_and_territories_by_historical_population"

population = (pd.read_html(POP_URL)[3]
                .iloc[:-1]
                .melt(id_vars="Name",
                      var_name="Year",
                      value_name="Population")
                .rename(columns={"Name": "State"})
                .fillna(0)
                .astype({
                    "Year": np.int64,
                    "Population": np.float64
                })
                .set_index(["State", "Year"])
                .sort_index()
                ["Population"])
In [54]:
population.head()
Out[54]:
State    Year
Alabama  1960    3266740.0
         1970    3444165.0
         1980    3893888.0
         1990    4040587.0
         2000    4447100.0
Name: Population, dtype: float64
In [55]:
population.tail()
Out[55]:
State    Year
Wyoming  1980    469557.0
         1990    453588.0
         2000    493782.0
         2010    563626.0
         2020    576851.0
Name: Population, dtype: float64

The series population contains the population of each state according to the United States census conducted every ten years.

In [56]:
COAST_URL = "https://en.wikipedia.org/wiki/List_of_U.S._states_and_territories_by_coastline"

coast_df, _ = pd.read_html(COAST_URL)
coast_df = (coast_df[["State or territory", "Method 1 (CRS)", "Coast/area ratio (ft/mi2)"]]
                    .iloc[:-1])
coast_df.columns = coast_df.columns.droplevel(0)
coast_df = coast_df.drop(["Rank", "Method 2"], axis=1)
coast_df = (coast_df.rename(columns={
                        "State or territory": "State",
                        "Method 1": "Coastline to area"
                    })
                    .set_index("State")
                    .sort_index())
In [57]:
coast_df.head()
Out[57]:
Coastline Coastline to area
State
Alabama 53 mi (85 km) 5.3
Alaska 6,640 mi (10,690 km) 53
American Samoa
California 840 mi (1,350 km) 27
Connecticut 96 mi (154 km) 91
In [58]:
coast_df["Coastline"] = (
    coast_df["Coastline"]
            .str.split(expand=True)
            .iloc[:, 0]
            .str.replace(",", "")
            .str.replace("[–—]", "0", regex=True)
            .astype(np.float64)
)
coast_df["Coastline to area"] = (
    coast_df["Coastline to area"]
            .astype(str)
            .str.replace("[–—]", "-1", regex=True)
            .astype(np.float64)
            .replace(-1, np.nan)
)
In [59]:
coast_df.head()
Out[59]:
Coastline Coastline to area
State
Alabama 53.0 5.3
Alaska 6640.0 53.0
American Samoa 0.0 NaN
California 840.0 27.0
Connecticut 96.0 91.0
In [60]:
coast_df.tail()
Out[60]:
Coastline Coastline to area
State
U.S. Minor Outlying Islands 0.0 NaN
U.S. Virgin Islands 0.0 NaN
Virginia 112.0 14.0
Washington 157.0 12.0
Wisconsin 0.0 NaN

The data frame coast_df contains the length of each state's coastline (in miles) as well as the ratio of its coastline to its area (in feet of coastline per square mile of area).

We now combine attacks_df, population, and coast_df into a single data frame.

In [61]:
attacks_df = (attacks_df.merge(coast_df,
                               left_index=True, right_index=True)
                        .merge(population, how='left',
                               left_index=True, right_index=True)
                        .fillna(method='ffill'))
In [62]:
attacks_df["Population to coastline"] = attacks_df["Population"] / attacks_df["Coastline"]
In [63]:
attacks_df.head()
Out[63]:
Attacks Coastline Coastline to area Population Population to coastline
State Year
Alabama 1960 0 53.0 5.3 3266740.0 61636.603774
1961 0 53.0 5.3 3266740.0 61636.603774
1962 0 53.0 5.3 3266740.0 61636.603774
1963 0 53.0 5.3 3266740.0 61636.603774
1964 0 53.0 5.3 3266740.0 61636.603774
In [64]:
attacks_df.tail()
Out[64]:
Attacks Coastline Coastline to area Population Population to coastline
State Year
Washington 2015 0 157.0 12.0 6724540.0 42831.464968
2016 0 157.0 12.0 6724540.0 42831.464968
2017 1 157.0 12.0 6724540.0 42831.464968
2018 0 157.0 12.0 6724540.0 42831.464968
2019 0 157.0 12.0 6724540.0 42831.464968

Here the population data is from the most recent United States census prior to the year in question (thanks to fillna(method='ffill')), as the quick check below illustrates. We then plot the relationship between the four explanatory variables and the number of attacks.
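As a quick sanity check on that forward fill, here is a minimal sketch using the Alabama census figures shown above; intermediate years simply inherit the most recent prior census value.

import pandas as pd

# Census populations for Alabama, taken from the population series above
census = pd.Series({1960: 3_266_740.0, 1970: 3_444_165.0}, name="Population")

# Reindex to an annual frequency and forward fill, mirroring the merge + ffill above
annual = census.reindex(range(1960, 1975)).ffill()
annual.loc[[1960, 1965, 1970, 1974]]
# 1960    3266740.0
# 1965    3266740.0  <- still the 1960 census value
# 1970    3444165.0
# 1974    3444165.0  <- still the 1970 census value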

In [65]:
fig, axes = plt.subplots(ncols=2, nrows=2, sharey=True,
                         figsize=(12, 9))

for col, ax in zip(attacks_df.columns[1:], axes.ravel()):
    attacks_df.plot.scatter(col, "Attacks",
                            color='C0', alpha=0.5,
                            ax=ax);
    
    ax.set_xscale('log');
    ax.set_yscale('log');

fig.tight_layout();

The top two and bottom two plots are very similar. We choose to use coastline length (in miles) and population as our predictors since those relationships seem to be a bit more linear on the log-log scale.
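This choice also lines up with the log-link regression we specify below: up to the standardization of the predictors, the model implies a power-law relationship

$$\mu_{i, t} \propto \operatorname{coastline}_i^{\beta_{\text{coast}}} \cdot \operatorname{population}_{i, t}^{\beta_{\text{pop}}},$$

which is exactly the kind of relationship that looks approximately linear on the log-log scale.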

We standardize the logarithms of coastline length and population before using them as predictors.

In [66]:
def standardize(x):
    return (x - x.mean()) / x.std()
In [67]:
log_coast = np.log(attacks_df["Coastline"].values)
x_coast = at.constant(standardize(log_coast))
In [68]:
log_pop = np.log(attacks_df["Population"].values)
x_pop = at.constant(standardize(log_pop))

We use the priors $\beta_0, \beta_{\text{coast}}, \beta_{\text{pop}} \sim N(0, 2.5^2)$ on the regression coefficients and set

$$\eta_{i, t} = \beta_0 + \beta_{\text{coast}} \cdot x_{\text{coast}, i} + \beta_{\text{pop}} \cdot x_{\text{pop}, i, t}.$$

The mean is then $\mu_{i, t} = \exp \eta_{i, t}.$

In [69]:
with pm.Model() as nb_reg_model:
    β0 = pm.Normal("β0", 0., 2.5)
    β_coast = pm.Normal("β_coast", 0., 2.5)
    β_pop = pm.Normal("β_pop", 0., 2.5)
    η = β0 + β_coast * x_coast + β_pop * x_pop
    μ = at.exp(η)

As in the previous model, $\alpha \sim \operatorname{Half}-N(2.5^2)$ and $y_{i, t} \sim NB(\mu_{i, t}, \alpha)$.

In [70]:
with nb_reg_model:
    α = pm.HalfNormal("α", 2.5)
    obs = pm.NegativeBinomial("obs", μ, α, observed=y)

We again sample from the posterior distribution of the model.

In [71]:
with nb_reg_model:
    nb_reg_trace = pm.sample(**SAMPLE_KWARGS)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [β0, β_coast, β_pop, α]
100.00% [6000/6000 00:24<00:00 Sampling 3 chains, 0 divergences]
Sampling 3 chains for 1_000 tune and 1_000 draw iterations (3_000 + 3_000 draws total) took 25 seconds.

As before, the standard sampling diagnostics show no cause for concern.

In [72]:
make_diagnostic_plots(nb_reg_trace);

To see if the inclusion of these predictors has improved the model fit, we sample from and plot the posterior predictive distributions per state-year.

In [73]:
with nb_reg_model:
    nb_reg_pp_trace = pm.sample_posterior_predictive(nb_reg_trace)
100.00% [3000/3000 00:02<00:00]
In [74]:
nb_reg_pp_full_df = make_pp_full_df(nb_reg_pp_trace["obs"])
In [75]:
nb_reg_pp_df = summarize_pp_df(nb_reg_pp_full_df)
grid = sns.FacetGrid(nb_reg_pp_df.reset_index(), col="State",
                     col_wrap=2, sharey=False, aspect=1.5)
grid.map_dataframe(plot_pp_over_time);

grid.axes[0].legend(loc="upper left");
grid.set_titles("{col_name}");
grid.fig.tight_layout();