!pip install -q numpyro arviz
import math
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import nn, random, vmap
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import MCMC, NUTS, Predictive, SVI, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
d["treatment"] = d.prosoc_left + 2 * d.condition
d.reset_index().groupby(["condition", "prosoc_left", "treatment"]).count()["index"]
condition  prosoc_left  treatment
0          0            0            126
           1            1            126
1          0            2            126
           1            3            126
Name: index, dtype: int64
def model(pulled_left=None):
a = numpyro.sample("a", dist.Normal(0, 10))
logit_p = a
numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)
m11_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m11_1, optim.Adam(1), Trace_ELBO(), pulled_left=d.pulled_left.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p11_1 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1707.80it/s, init loss: 346.7780, avg. loss [951-1000]: 346.1921]
prior = Predictive(m11_1.model, num_samples=10000)(random.PRNGKey(1999))
p = expit(prior["a"])
az.plot_kde(p)
plt.show()
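# The KDE above shows the problem: a flat-looking Normal(0, 10) prior on the
# logit scale piles most of its mass near 0 and 1 on the probability scale.
# A rough numeric check (the 0.01/0.99 cutoffs here are arbitrary):
print(jnp.mean((p < 0.01) | (p > 0.99)))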
def model(treatment, pulled_left=None):
a = numpyro.sample("a", dist.Normal(0, 1.5))
b = numpyro.sample("b", dist.Normal(0, 10).expand([4]))
logit_p = a + b[treatment]
numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)
m11_2 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m11_2,
optim.Adam(1),
Trace_ELBO(),
treatment=d.treatment.values,
pulled_left=d.pulled_left.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p11_2 = svi_result.params
prior = Predictive(model, num_samples=int(1e4))(
random.PRNGKey(1999), treatment=0, pulled_left=0
)
p = vmap(lambda k: expit(prior["a"] + prior["b"][:, k]), 0, 1)(jnp.arange(4))
100%|██████████| 1000/1000 [00:00<00:00, 1238.06it/s, init loss: 414.9600, avg. loss [951-1000]: 351.7402]
az.plot_kde(jnp.abs(p[:, 0] - p[:, 1]), bw=0.3)
plt.show()
def model(treatment, pulled_left=None):
a = numpyro.sample("a", dist.Normal(0, 1.5))
b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
logit_p = a + b[treatment]
numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)
m11_3 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m11_3,
optim.Adam(1),
Trace_ELBO(),
treatment=d.treatment.values,
pulled_left=d.pulled_left.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p11_3 = svi_result.params
prior = Predictive(model, num_samples=int(1e4))(
random.PRNGKey(1999), treatment=0, pulled_left=0
)
p = vmap(lambda k: expit(prior["a"] + prior["b"][:, k]), 0, 1)(jnp.arange(4))
jnp.mean(jnp.abs(p[:, 0] - p[:, 1]))
100%|██████████| 1000/1000 [00:00<00:00, 1329.70it/s, init loss: 414.5659, avg. loss [951-1000]: 340.4073]
DeviceArray(0.09770478, dtype=float32)
# trimmed data list
dat_list = {
"pulled_left": d.pulled_left.values,
"actor": d.actor.values - 1,
"treatment": d.treatment.values,
}
def model(actor, treatment, pulled_left=None, link=False):
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
logit_p = a[actor] + b[treatment]
if link:
numpyro.deterministic("p", expit(logit_p))
numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)
m11_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_4.run(random.PRNGKey(0), **dat_list)
m11_4.print_summary(0.89)
        mean   std  median   5.5%  94.5%    n_eff  r_hat
a[0]   -0.44  0.33   -0.44  -0.94   0.09   653.23   1.00
a[1]    3.90  0.75    3.83   2.67   4.96  1776.39   1.00
a[2]   -0.75  0.33   -0.75  -1.29  -0.23   765.18   1.00
a[3]   -0.73  0.34   -0.73  -1.25  -0.18   707.96   1.00
a[4]   -0.45  0.33   -0.46  -0.93   0.13   721.43   1.00
a[5]    0.49  0.32    0.49  -0.04   0.97   666.96   1.00
a[6]    1.97  0.43    1.95   1.33   2.69   881.18   1.00
b[0]   -0.04  0.28   -0.05  -0.48   0.40   624.20   1.00
b[1]    0.48  0.28    0.48   0.00   0.90   565.12   1.00
b[2]   -0.39  0.28   -0.38  -0.84   0.05   641.65   1.00
b[3]    0.36  0.28    0.36  -0.09   0.81   642.48   1.00

Number of divergences: 0
post = m11_4.get_samples(group_by_chain=True)
p_left = expit(post["a"])
az.plot_forest({"p_left": p_left}, combined=True, hdi_prob=0.89)
plt.gca().set(xlim=(-0.01, 1.01))
plt.show()
labs = ["R/N", "L/N", "R/P", "L/P"]
az.plot_forest(
m11_4.get_samples(group_by_chain=True),
combined=True,
var_names="b",
hdi_prob=0.89,
)
plt.gca().set_yticklabels(labs[::-1])
plt.show()
diffs = {
"db13": post["b"][..., 0] - post["b"][..., 2],
"db24": post["b"][..., 1] - post["b"][..., 3],
}
az.plot_forest(diffs, combined=True)
plt.show()
pl = d.groupby(["actor", "treatment"])["pulled_left"].mean().unstack()
pl.iloc[0, :]
treatment
0    0.333333
1    0.500000
2    0.277778
3    0.555556
Name: 1, dtype: float64
ax = plt.subplot(
xlim=(0.5, 28.5),
ylim=(0, 1.05),
xlabel="",
ylabel="proportion left lever",
xticks=[],
)
plt.yticks(ticks=[0, 0.5, 1], labels=[0, 0.5, 1])
ax.axhline(0.5, c="k", lw=1, ls="--")
for j in range(1, 8):
ax.axvline((j - 1) * 4 + 4.5, c="k", lw=0.5)
for j in range(1, 8):
ax.annotate(
"actor {}".format(j),
((j - 1) * 4 + 2.5, 1.1),
ha="center",
va="center",
annotation_clip=False,
)
for j in [1] + list(range(3, 8)):
ax.plot((j - 1) * 4 + jnp.array([1, 3]), pl.loc[j, [0, 2]], "b")
ax.plot((j - 1) * 4 + jnp.array([2, 4]), pl.loc[j, [1, 3]], "b")
x = jnp.arange(1, 29).reshape(7, 4)
ax.scatter(
x[:, [0, 1]].reshape(-1),
pl.values[:, [0, 1]].reshape(-1),
edgecolor="b",
facecolor="w",
zorder=3,
)
ax.scatter(
x[:, [2, 3]].reshape(-1), pl.values[:, [2, 3]].reshape(-1), marker=".", c="b", s=80
)
yoff = 0.01
ax.annotate("R/N", (1, pl.loc[1, 0] - yoff), ha="center", va="top")
ax.annotate("L/N", (2, pl.loc[1, 1] + yoff), ha="center", va="bottom")
ax.annotate("R/P", (3, pl.loc[1, 2] - yoff), ha="center", va="top")
ax.annotate("L/P", (4, pl.loc[1, 3] + yoff), ha="center", va="bottom")
ax.set_title("observed proportions\n")
plt.show()
dat = {"actor": jnp.repeat(jnp.arange(7), 4), "treatment": jnp.tile(jnp.arange(4), 7)}
pred = Predictive(m11_4.sampler.model, m11_4.get_samples(), return_sites=["p"])
p_post = pred(random.PRNGKey(1), link=True, **dat)["p"]
p_mu = jnp.mean(p_post, 0)
p_ci = jnp.percentile(p_post, q=jnp.array([5.5, 94.5]), axis=0)
d["side"] = d.prosoc_left # right 0, left 1
d["cond"] = d.condition # no partner 0, partner 1
dat_list2 = {
"pulled_left": d.pulled_left.values,
"actor": d.actor.values - 1,
"side": d.side.values,
"cond": d.cond.values,
}
def model(actor, side, cond, pulled_left=None):
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
bs = numpyro.sample("bs", dist.Normal(0, 0.5).expand([2]))
bc = numpyro.sample("bc", dist.Normal(0, 0.5).expand([2]))
logit_p = a[actor] + bs[side] + bc[cond]
numpyro.sample("pulled_left", dist.Binomial(logits=logit_p), obs=pulled_left)
m11_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_5.run(random.PRNGKey(0), **dat_list2)
az.compare(
{"m11.5": az.from_numpyro(m11_5), "m11.4": az.from_numpyro(m11_4)},
ic="loo",
scale="deviance",
)
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
|       | rank |        loo |    p_loo |    d_loo | weight |        se |      dse | warning | loo_scale |
|-------|------|------------|----------|----------|--------|-----------|----------|---------|-----------|
| m11.5 |    0 | 530.868595 | 7.771968 | 0.000000 |    1.0 | 19.077177 | 0.000000 | False   | deviance  |
| m11.4 |    1 | 532.349813 | 8.527458 | 1.481217 |    0.0 | 18.951459 | 1.212535 | False   | deviance  |
post = m11_4.get_samples()
post["log_lik"] = log_likelihood(m11_4.sampler.model, post, **dat_list)["pulled_left"]
{k: v.shape for k, v in post.items()}
{'a': (2000, 7), 'b': (2000, 4), 'log_lik': (2000, 504)}
def m11_4_pe_code(params, log_lik=False):
a_logprob = jnp.sum(dist.Normal(0, 1.5).log_prob(params["a"]))
b_logprob = jnp.sum(dist.Normal(0, 0.5).log_prob(params["b"]))
logit_p = params["a"][dat_list["actor"]] + params["b"][dat_list["treatment"]]
pulled_left_logprob = dist.Binomial(logits=logit_p).log_prob(
dat_list["pulled_left"]
)
if log_lik:
return pulled_left_logprob
return -(a_logprob + b_logprob + jnp.sum(pulled_left_logprob))
m11_4_pe = MCMC(
NUTS(potential_fn=m11_4_pe_code), num_warmup=1000, num_samples=1000, num_chains=4
)
init_params = {"a": jnp.zeros((4, 7)), "b": jnp.zeros((4, 4))}
m11_4_pe.run(random.PRNGKey(0), init_params=init_params)
log_lik = vmap(lambda p: m11_4_pe_code(p, log_lik=True))(m11_4_pe.get_samples())
m11_4_pe_az = az.from_numpyro(m11_4_pe)
m11_4_pe_az.sample_stats["log_likelihood"] = (
("chain", "draw", "log_lik"),
jnp.reshape(log_lik, (4, 1000, -1)),
)
az.compare(
{"m11.4_pe": m11_4_pe_az, "m11.4": az.from_numpyro(m11_4)},
ic="waic",
scale="deviance",
)
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
|          | rank |       waic |   p_waic |   d_waic | weight |        se |      dse | warning | waic_scale |
|----------|------|------------|----------|----------|--------|-----------|----------|---------|------------|
| m11.4_pe |    0 | 532.006465 | 8.366444 | 0.000000 |    1.0 | 18.875380 | 0.000000 | False   | deviance   |
| m11.4    |    1 | 532.342706 | 8.523905 | 0.336241 |    0.0 | 18.951188 | 0.192587 | False   | deviance   |
post = m11_4.get_samples()
jnp.mean(jnp.exp(post["b"][:, 3] - post["b"][:, 1]))
DeviceArray(0.9210905, dtype=float32)
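# On the odds scale this is the proportional-odds reading: adding a partner to
# the left-prosocial option multiplies the odds of pulling left by about 0.92,
# i.e. roughly an 8% reduction. Just a rephrasing of the value above:
print(1 - jnp.mean(jnp.exp(post["b"][:, 3] - post["b"][:, 1])))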
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
d["treatment"] = d.prosoc_left + 2 * d.condition
d["side"] = d.prosoc_left # right 0, left 1
d["cond"] = d.condition # no partner 0, partner 1
d_aggregated = (
d.groupby(["treatment", "actor", "side", "cond"])["pulled_left"].sum().reset_index()
)
d_aggregated.rename(columns={"pulled_left": "left_pulls"}, inplace=True)
dat = dict(zip(d_aggregated.columns, d_aggregated.values.T))
def model(actor, treatment, left_pulls):
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([7]))
b = numpyro.sample("b", dist.Normal(0, 0.5).expand([4]))
logit_p = a[actor] + b[treatment]
numpyro.sample("left_pulls", dist.Binomial(18, logits=logit_p), obs=left_pulls)
m11_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_6.run(
random.PRNGKey(0),
actor=dat["actor"] - 1,
treatment=dat["treatment"],
left_pulls=dat["left_pulls"],
)
try:
az.compare(
{"m11.6": az.from_numpyro(m11_6), "m11.4": az.from_numpyro(m11_4)},
ic="loo",
scale="deviance",
)
except Exception as e:
warnings.warn("\n{}: {}".format(type(e).__name__, e))
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
UserWarning: ValueError: The number of observations should be the same across all models
# deviance of aggregated 6-in-9
print(-2 * dist.Binomial(9, 0.2).log_prob(6))
# deviance of dis-aggregated
print(
-2 * jnp.sum(dist.Bernoulli(0.2).log_prob(jnp.array([1, 1, 1, 1, 1, 1, 0, 0, 0])))
)
11.790477
20.652117
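# The gap between the two deviances is exactly the multiplicity term
# 2 * log C(9, 6) that the aggregated likelihood carries and the disaggregated
# one does not; it says nothing about p. A quick check, assuming
# jax.scipy.special.gammaln for the log binomial coefficient:
from jax.scipy.special import gammaln

log_choose = gammaln(10.0) - gammaln(7.0) - gammaln(4.0)  # log C(9, 6)
print(2 * log_choose)  # ~8.86 = 20.652117 - 11.790477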
UCBadmit = pd.read_csv("../data/UCBadmit.csv", sep=";")
d = UCBadmit
dat_list = dict(
admit=d.admit.values,
applications=d.applications.values,
gid=(d["applicant.gender"] != "male").astype(int).values,
)
def model(gid, applications, admit=None):
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
logit_p = a[gid]
numpyro.sample("admit", dist.Binomial(applications, logits=logit_p), obs=admit)
m11_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_7.run(random.PRNGKey(0), **dat_list)
m11_7.print_summary(0.89)
        mean   std  median   5.5%  94.5%    n_eff  r_hat
a[0]   -0.22  0.04   -0.22  -0.28  -0.16  2000.51   1.00
a[1]   -0.83  0.05   -0.83  -0.91  -0.75  1905.84   1.00

Number of divergences: 0
post = m11_7.get_samples()
diff_a = post["a"][:, 0] - post["a"][:, 1]
diff_p = expit(post["a"][:, 0]) - expit(post["a"][:, 1])
print_summary({"diff_a": diff_a, "diff_p": diff_p}, 0.89, False)
          mean   std  median   5.5%  94.5%    n_eff  r_hat
diff_a    0.61  0.06    0.61   0.51   0.71  1924.35   1.00
diff_p    0.14  0.01    0.14   0.12   0.16  1932.49   1.00
post = m11_7.get_samples()
admit_pred = Predictive(m11_7.sampler.model, post)(
random.PRNGKey(2), gid=dat_list["gid"], applications=dat_list["applications"]
)["admit"]
admit_rate = admit_pred / d.applications.values
plt.errorbar(
range(1, 13),
jnp.mean(admit_rate, 0),
jnp.std(admit_rate, 0) / 2,
fmt="o",
c="k",
mfc="none",
ms=7,
elinewidth=1,
)
plt.plot(range(1, 13), jnp.percentile(admit_rate, 5.5, 0), "k+")
plt.plot(range(1, 13), jnp.percentile(admit_rate, 94.5, 0), "k+")
# draw lines connecting points from same dept
for i in range(1, 7):
x = 1 + 2 * (i - 1)
y1 = d.admit.iloc[x - 1] / d.applications.iloc[x - 1]
y2 = d.admit.iloc[x] / d.applications.iloc[x]
plt.plot((x, x + 1), (y1, y2), "bo-")
plt.annotate(
d.dept.iloc[x], (x + 0.5, (y1 + y2) / 2 + 0.05), ha="center", color="royalblue"
)
plt.gca().set(ylim=(0, 1), xticks=range(1, 13), ylabel="admit", xlabel="case")
plt.show()
dat_list["dept_id"] = jnp.repeat(jnp.arange(6), 2)
def model(gid, dept_id, applications, admit=None):
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([2]))
delta = numpyro.sample("delta", dist.Normal(0, 1.5).expand([6]))
logit_p = a[gid] + delta[dept_id]
numpyro.sample("admit", dist.Binomial(applications, logits=logit_p), obs=admit)
m11_8 = MCMC(NUTS(model), num_warmup=2000, num_samples=2000, num_chains=4)
m11_8.run(random.PRNGKey(0), **dat_list)
m11_8.print_summary(0.89)
            mean   std  median   5.5%  94.5%   n_eff  r_hat
a[0]       -0.54  0.55   -0.54  -1.38   0.35  581.04   1.01
a[1]       -0.44  0.55   -0.44  -1.31   0.42  579.82   1.01
delta[0]    1.12  0.55    1.12   0.25   1.99  584.56   1.01
delta[1]    1.08  0.55    1.07   0.20   1.95  589.21   1.01
delta[2]   -0.14  0.55   -0.14  -1.02   0.73  583.25   1.01
delta[3]   -0.17  0.55   -0.17  -1.03   0.72  588.07   1.01
delta[4]   -0.62  0.55   -0.62  -1.50   0.25  584.70   1.01
delta[5]   -2.17  0.57   -2.17  -3.12  -1.31  602.85   1.00

Number of divergences: 0
post = m11_8.get_samples()
diff_a = post["a"][:, 0] - post["a"][:, 1]
diff_p = expit(post["a"][:, 0]) - expit(post["a"][:, 1])
print_summary({"diff_a": diff_a, "diff_p": diff_p}, 0.89, False)
          mean   std  median   5.5%  94.5%    n_eff  r_hat
diff_a   -0.10  0.08   -0.10  -0.22   0.03  9440.98   1.00
diff_p   -0.02  0.02   -0.02  -0.05   0.01  6953.82   1.00
pg = jnp.stack(
list(
map(
lambda k: jnp.divide(
d.applications[np.asarray(dat_list["dept_id"]) == k].values,
d.applications[np.asarray(dat_list["dept_id"]) == k].sum(),
),
range(6),
)
),
axis=0,
).T
pg = pd.DataFrame(pg, index=["male", "female"], columns=d.dept.unique())
pg.round(2)
|        |    A |    B |    C |    D |    E |    F |
|--------|------|------|------|------|------|------|
| male   | 0.88 | 0.96 | 0.35 | 0.53 | 0.33 | 0.52 |
| female | 0.12 | 0.04 | 0.65 | 0.47 | 0.67 | 0.48 |
y = dist.Binomial(1000, 1 / 1000).sample(random.PRNGKey(0), (int(1e5),))
jnp.mean(y), jnp.var(y)
(DeviceArray(0.99552995, dtype=float32), DeviceArray(1.00013, dtype=float32))
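# Both moments sit near 1, as the Poisson limit predicts: analytically the
# Binomial mean is N * p and the variance N * p * (1 - p), which differ here
# only by a factor of 0.999. A quick check:
N_, p_ = 1000, 1 / 1000
print(N_ * p_, N_ * p_ * (1 - p_))  # 1.0 0.999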
Kline = pd.read_csv("../data/Kline.csv", sep=";")
d = Kline
d
|   | culture    | population | contact | total_tools | mean_TU |
|---|------------|------------|---------|-------------|---------|
| 0 | Malekula   |       1100 | low     |          13 |     3.2 |
| 1 | Tikopia    |       1500 | low     |          22 |     4.7 |
| 2 | Santa Cruz |       3600 | low     |          24 |     4.0 |
| 3 | Yap        |       4791 | high    |          43 |     5.0 |
| 4 | Lau Fiji   |       7400 | high    |          33 |     5.0 |
| 5 | Trobriand  |       8000 | high    |          19 |     4.0 |
| 6 | Chuuk      |       9200 | high    |          40 |     3.8 |
| 7 | Manus      |      13000 | low     |          28 |     6.6 |
| 8 | Tonga      |      17500 | high    |          55 |     5.4 |
| 9 | Hawaii     |     275000 | low     |          71 |     6.6 |
d["P"] = d.population.apply(math.log).pipe(lambda x: (x - x.mean()) / x.std())
d["contact_id"] = (d.contact == "high").astype(int)
x = jnp.linspace(0, 100, 200)
plt.plot(x, jnp.exp(dist.LogNormal(0, 10).log_prob(x)))
plt.show()
a = dist.Normal(0, 10).sample(random.PRNGKey(0), (int(1e4),))
lambda_ = jnp.exp(a)
jnp.mean(lambda_)
DeviceArray(1.1725839e+12, dtype=float32)
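# The sampled mean is huge, yet still far below the truth: a LogNormal has
# mean exp(mu + sigma^2 / 2), which for mu=0, sigma=10 is exp(50), about 5e21.
# With so heavy a tail, 1e4 draws badly underestimate it:
print(jnp.exp(0 + 10**2 / 2))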
x = jnp.linspace(0, 100, 200)
plt.plot(x, jnp.exp(dist.LogNormal(3, 0.5).log_prob(x)))
plt.show()
N = 100
a = dist.Normal(3, 0.5).sample(random.PRNGKey(0), (N,))
b = dist.Normal(0, 10).sample(random.PRNGKey(1), (N,))
plt.subplot(xlim=(-2, 2), ylim=(0, 100))
x = jnp.linspace(-2, 2, 100)
for i in range(N):
plt.plot(x, jnp.exp(a[i] + b[i] * x), c="k", alpha=0.5)
with numpyro.handlers.seed(rng_seed=10):
N = 100
a = numpyro.sample("a", dist.Normal(3, 0.5).expand([N]))
b = numpyro.sample("a", dist.Normal(0, 0.2).expand([N]))
plt.subplot(xlim=(-2, 2), ylim=(0, 100))
x = jnp.linspace(-2, 2, 100)
for i in range(N):
plt.plot(x, jnp.exp(a[i] + b[i] * x), c="k", alpha=0.5)
x_seq = jnp.linspace(jnp.log(100), jnp.log(200000), num=100)
lambda_ = vmap(lambda x: jnp.exp(a + b * x), out_axes=1)(x_seq)
plt.subplot(
xlim=(jnp.min(x_seq).item(), jnp.max(x_seq).item()),
ylim=(0, 500),
xlabel="log population",
ylabel="total tools",
)
for i in range(N):
plt.plot(x_seq, lambda_[i], c="k", alpha=0.5)
plt.subplot(
xlim=(jnp.min(jnp.exp(x_seq)).item(), jnp.max(jnp.exp(x_seq)).item()),
ylim=(0, 500),
xlabel="population",
ylabel="total tools",
)
for i in range(N):
plt.plot(jnp.exp(x_seq), lambda_[i], c="k", alpha=0.5)
dat = dict(T=d.total_tools.values, P=d.P.values, cid=d.contact_id.values)
# intercept only
def model(T=None):
a = numpyro.sample("a", dist.Normal(3, 0.5))
lambda_ = numpyro.deterministic("lambda", jnp.exp(a))
numpyro.sample("T", dist.Poisson(lambda_), obs=T)
m11_9 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_9.run(random.PRNGKey(0), dat["T"])
# interaction model
def model(cid, P, T=None):
a = numpyro.sample("a", dist.Normal(3, 0.5).expand([2]))
b = numpyro.sample("b", dist.Normal(0, 0.2).expand([2]))
lambda_ = numpyro.deterministic("lambda", jnp.exp(a[cid] + b[cid] * P))
numpyro.sample("T", dist.Poisson(lambda_), obs=T)
m11_10 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_10.run(random.PRNGKey(0), **dat)
az.compare(
{"m11.9": az.from_numpyro(m11_9), "m11.10": az.from_numpyro(m11_10)},
ic="loo",
scale="deviance",
)
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
|        | rank |        loo |    p_loo |    d_loo |  weight |        se |       dse | warning | loo_scale |
|--------|------|------------|----------|----------|---------|-----------|-----------|---------|-----------|
| m11.10 |    0 |  85.544697 | 7.124397 |  0.00000 | 0.98207 | 12.345151 |  0.000000 | True    | deviance  |
| m11.9  |    1 | 142.276767 | 8.786926 | 56.73207 | 0.01793 | 32.278971 | 31.019185 | True    | deviance  |
k = az.loo(az.from_numpyro(m11_10), pointwise=True).pareto_k.values
cex = 1 + (k - jnp.min(k)) / (jnp.max(k) - jnp.min(k))
plt.scatter(
dat["P"],
dat["T"],
s=40 * cex,
edgecolors=["none" if i == 1 else "b" for i in dat["cid"]],
facecolors=["none" if i == 0 else "b" for i in dat["cid"]],
)
plt.gca().set(xlabel="log population (std)", ylabel="total tools", ylim=(0, 75))
# set up the horizontal axis values to compute predictions at
ns = 100
P_seq = jnp.linspace(-1.4, 3, num=ns)
# predictions for cid=0 (low contact)
post = m11_10.get_samples()
post.pop("lambda")
lambda_ = Predictive(m11_10.sampler.model, post)(
random.PRNGKey(1), P=P_seq, cid=0
)["lambda"]
lmu = jnp.mean(lambda_, 0)
lci = jnp.percentile(lambda_, jnp.array([5.5, 94.5]), 0)
plt.plot(P_seq, lmu, "k--", lw=1.5)
plt.fill_between(P_seq, lci[0], lci[1], color="k", alpha=0.2)
# predictions for cid=1 (high contact)
lambda_ = Predictive(m11_10.sampler.model, post)(
random.PRNGKey(1), P=P_seq, cid=1
)["lambda"]
lmu = jnp.mean(lambda_, 0)
lci = jnp.percentile(lambda_, jnp.array([5.5, 94.5]), 0)
plt.plot(P_seq, lmu, "k", lw=1.5)
plt.fill_between(P_seq, lci[0], lci[1], color="k", alpha=0.2)
plt.show()
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
cex = 1 + (k - jnp.min(k)) / (jnp.max(k) - jnp.min(k))
plt.scatter(
d.population,
d.total_tools,
s=40 * cex,
edgecolors=["none" if i == 1 else "b" for i in dat["cid"]],
facecolors=["none" if i == 0 else "b" for i in dat["cid"]],
)
plt.gca().set(xlabel="population", ylabel="total tools", xlim=(0, 300000), ylim=(0, 75))
ns = 100
P_seq = jnp.linspace(-5, 3, num=ns)
# 1.53 is sd of log(population)
# 9 is mean of log(population)
pop_seq = jnp.exp(P_seq * 1.53 + 9)
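# Rather than trusting the hard-coded constants, 1.53 and 9 can be recovered
# from the data (a sketch; log_pop is a throwaway name):
log_pop = d.population.apply(math.log)
print(log_pop.std(), log_pop.mean())  # ~1.53, ~9.0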
lambda_ = Predictive(m11_10.sampler.model, m11_10.get_samples())(
random.PRNGKey(1), P=P_seq, cid=0
)["lambda"]
lmu = jnp.mean(lambda_, 0)
lci = jnp.percentile(lambda_, jnp.array([5.5, 94.5]), 0)
plt.plot(pop_seq, lmu, "k--", lw=1.5)
plt.fill_between(pop_seq, lci[0], lci[1], color="k", alpha=0.2)
lambda_ = Predictive(m11_10.sampler.model, m11_10.get_samples())(
random.PRNGKey(1), P=P_seq, cid=1
)["lambda"]
lmu = jnp.mean(lambda_, 0)
lci = jnp.percentile(lambda_, jnp.array([5.5, 94.5]), 0)
plt.plot(pop_seq, lmu, "k", lw=1.5)
plt.fill_between(pop_seq, lci[0], lci[1], color="k", alpha=0.2)
plt.show()
dat2 = dict(T=d.total_tools.values, P=d.population.values, cid=d.contact_id.values)
def model(cid, P, T):
a = numpyro.sample("a", dist.Normal(1, 1).expand([2]))
b = numpyro.sample("b", dist.Exponential(1).expand([2]))
g = numpyro.sample("g", dist.Exponential(1))
lambda_ = jnp.exp(a[cid]) * jnp.power(P, b[cid]) / g
numpyro.sample("T", dist.Poisson(lambda_), obs=T)
m11_11 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m11_11.run(random.PRNGKey(0), **dat2)
num_days = 30
y = dist.Poisson(1.5).sample(random.PRNGKey(0), (num_days,))
num_weeks = 4
y_new = dist.Poisson(0.5 * 7).sample(random.PRNGKey(0), (num_weeks,))
y_all = jnp.concatenate([y, y_new])
exposure = jnp.concatenate([jnp.repeat(1, 30), jnp.repeat(7, 4)])
monastery = jnp.concatenate([jnp.repeat(0, 30), jnp.repeat(1, 4)])
d = pd.DataFrame.from_dict(dict(y=y_all, days=exposure, monastery=monastery))
# compute the offset
d["log_days"] = d.days.apply(math.log)
def model(log_days, monastery, y):
a = numpyro.sample("a", dist.Normal(0, 1))
b = numpyro.sample("b", dist.Normal(0, 1))
lambda_ = jnp.exp(log_days + a + b * monastery)
numpyro.sample("T", dist.Poisson(lambda_), obs=y)
m11_12 = MCMC(NUTS(model), num_warmup=500, num_samples=500)
m11_12.run(random.PRNGKey(0), d.log_days.values, d.monastery.values, d.y.values)
sample: 100%|██████████| 1000/1000 [00:02<00:00, 418.26it/s, 1 steps of size 6.69e-01. acc. prob=0.93]
post = m11_12.get_samples()
lambda_old = jnp.exp(post["a"])
lambda_new = jnp.exp(post["a"] + post["b"])
print_summary(dict(lambda_old=lambda_old, lambda_new=lambda_new), 0.89, False)
              mean   std  median   5.5%  94.5%   n_eff  r_hat
lambda_new    0.53  0.14    0.52   0.31   0.73  503.83   1.00
lambda_old    1.70  0.22    1.67   1.38   2.04  241.48   1.00
# simulate career choices among 500 individuals
N = 500 # number of individuals
income = jnp.array([1, 2, 5]) # expected income of each career
score = 0.5 * income # scores for each career, based on income
# next line converts scores to probabilities
p = nn.softmax(score)
# now simulate choice
# outcome career holds event type values, not counts
career = jnp.repeat(jnp.nan, N) # empty vector of choices for each individual
# sample chosen career for each individual
for i in range(N):
career = career.at[i].set(
dist.Categorical(probs=p).sample(random.PRNGKey(34302 + i))
)
career = career.astype(jnp.int32)
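# The loop draws one PRNG key per individual; an equivalent vectorized draw
# uses a single key and a sample shape (a sketch -- the realized values differ
# from the per-key loop because the random streams differ):
career_vec = dist.Categorical(probs=p).sample(random.PRNGKey(34302), (N,))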
def model_m11_13(N, K, career_income, career):
# intercepts
a = numpyro.sample("a", dist.Normal(0, 1).expand([K - 1]))
# association of income with choice
b = numpyro.sample("b", dist.HalfNormal(0.5))
s_1 = a[0] + b * career_income[0]
s_2 = a[1] + b * career_income[1]
s_3 = 0 # pivot
p = nn.softmax(jnp.stack([s_1, s_2, s_3]))
numpyro.sample("career", dist.Categorical(p), obs=career)
dat_list = dict(N=N, K=3, career=career, career_income=income)
m11_13 = MCMC(NUTS(model_m11_13), num_warmup=1000, num_samples=1000, num_chains=4)
m11_13.run(random.PRNGKey(0), **dat_list)
m11_13.print_summary()
       mean   std  median   5.0%  95.0%  n_eff  r_hat
a[0]  -2.21  0.20   -2.19  -2.52  -1.88  84.60   1.05
a[1]  -1.85  0.30   -1.80  -2.29  -1.42  56.85   1.08
b      0.15  0.14    0.11   0.00   0.34  50.57   1.09

Number of divergences: 94
post = m11_13.get_samples()
# set up logit scores
s1 = post["a"][:, 0] + post["b"] * income[0]
s2_orig = post["a"][:, 1] + post["b"] * income[1]
s2_new = post["a"][:, 1] + post["b"] * income[1] * 2
# compute probabilities for original and counterfactual
p_orig = vmap(lambda s1, s2: nn.softmax(jnp.stack([s1, s2, 0])))(s1, s2_orig)
p_new = vmap(lambda s1, s2: nn.softmax(jnp.stack([s1, s2, 0])))(s1, s2_new)
# summarize
p_diff = p_new[:, 1] - p_orig[:, 1]
print_summary(p_diff, 0.89, False)
           mean   std  median   5.5%  94.5%  n_eff  r_hat
Param:0    0.05  0.05    0.03   0.00   0.11  47.54   1.05
N = 500
# simulate family incomes for each individual
family_income = dist.Uniform().sample(random.PRNGKey(0), (N,))
# assign a unique coefficient for each type of event
b = jnp.array([-2, 0, 2])
career = jnp.repeat(jnp.nan, N) # empty vector of choices for each individual
for i in range(N):
score = 0.5 * jnp.arange(1, 4) + b * family_income[i]
p = nn.softmax(score)
career = career.at[i].set(
dist.Categorical(probs=p).sample(random.PRNGKey(34302 + i))
)
career = career.astype(jnp.int32)
def model_m11_14(N, K, family_income, career):
# intercepts
a = numpyro.sample("a", dist.Normal(0, 1.5).expand([K - 1]))
# coefficients on family income
b = numpyro.sample("b", dist.Normal(0, 1).expand([K - 1]))
s = a + b * family_income[..., None]
s_K = jnp.zeros((N, 1)) # the pivot
p = nn.softmax(jnp.concatenate([s, s_K], -1))
numpyro.sample("career", dist.Categorical(p), obs=career)
dat_list = dict(N=N, K=3, career=career, family_income=family_income)
m11_14 = MCMC(NUTS(model_m11_14), num_warmup=1000, num_samples=1000, num_chains=4)
m11_14.run(random.PRNGKey(0), **dat_list)
m11_14.print_summary()
        mean   std  median   5.0%  95.0%    n_eff  r_hat
a[0]   -1.42  0.28   -1.42  -1.89  -0.97  1988.68   1.00
a[1]   -0.57  0.20   -0.57  -0.88  -0.22  1873.82   1.00
b[0]   -2.54  0.58   -2.54  -3.47  -1.60  2017.84   1.00
b[1]   -2.17  0.41   -2.16  -2.82  -1.47  1873.55   1.00

Number of divergences: 0
UCBadmit = pd.read_csv("../data/UCBadmit.csv", sep=";")
d = UCBadmit
# binomial model of overall admission probability
def model(applications, admit):
a = numpyro.sample("a", dist.Normal(0, 100))
logit_p = a
numpyro.sample("admit", dist.Binomial(applications, logits=logit_p), obs=admit)
m_binom = AutoLaplaceApproximation(model)
svi = SVI(
model,
m_binom,
optim.Adam(1),
Trace_ELBO(),
applications=d.applications.values,
admit=d.admit.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p_binom = svi_result.params
# Poisson model of overall admission rate and rejection rate
d["rej"] = d.reject
def model(rej, admit):
a1, a2 = numpyro.sample("a", dist.Normal(0, 100).expand([2]))
lambda1 = jnp.exp(a1)
lambda2 = jnp.exp(a2)
numpyro.sample("rej", dist.Poisson(lambda2), obs=rej)
numpyro.sample("admit", dist.Poisson(lambda1), obs=admit)
m_pois = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=3)
m_pois.run(random.PRNGKey(0), d.rej.values, d.admit.values)
100%|██████████| 1000/1000 [00:00<00:00, 1852.24it/s, init loss: 734.6700, avg. loss [951-1000]: 478.5217]
expit(m_binom.median(p_binom)["a"])
DeviceArray(0.3877594, dtype=float32)
k = jnp.mean(m_pois.get_samples()["a"], 0)
a1 = k[0]
a2 = k[1]
jnp.exp(a1) / (jnp.exp(a1) + jnp.exp(a2))
DeviceArray(0.38763952, dtype=float32)