!pip install -q numpyro arviz
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS, Predictive
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
reedfrogs = pd.read_csv("../data/reedfrogs.csv", sep=";")
d = reedfrogs
d.head()
density | pred | size | surv | propsurv | |
---|---|---|---|---|---|
0 | 10 | no | big | 9 | 0.9 |
1 | 10 | no | big | 10 | 1.0 |
2 | 10 | no | big | 7 | 0.7 |
3 | 10 | no | big | 10 | 1.0 |
4 | 10 | no | small | 9 | 0.9 |
# make the tank cluster variable (0-indexed row number, one id per tank)
d["tank"] = jnp.arange(d.shape[0])
# data dict for the models: survivors S out of N per tank
dat = dict(S=d.surv.values, N=d.density.values, tank=d.tank.values)
# approximate posterior
def model(tank, N, S):
    """Fixed-effects binomial survival model: one independent intercept per tank."""
    # Regularizing Normal(0, 1.5) prior, one intercept for every tank.
    tank_intercepts = numpyro.sample(
        "a", dist.Normal(0, 1.5), sample_shape=tank.shape
    )
    # Survivors out of N trials, modeled on the log-odds scale.
    numpyro.sample("S", dist.Binomial(N, logits=tank_intercepts[tank]), obs=S)


m13_1 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_1.run(random.PRNGKey(0), **dat)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
def model(tank, N, S):
    """Multilevel binomial model: tank intercepts partially pooled toward a_bar."""
    # Hyper-priors describing the population of tanks.
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # Adaptive prior: each tank's intercept is drawn from the learned population.
    intercepts = numpyro.sample(
        "a", dist.Normal(a_bar, sigma), sample_shape=tank.shape
    )
    numpyro.sample("S", dist.Binomial(N, logits=intercepts[tank]), obs=S)


m13_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_2.run(random.PRNGKey(0), **dat)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# Compare the two models on WAIC (deviance scale); partial pooling (m13.2)
# is favored and has a smaller effective number of parameters.
model_dict = {"m13.1": az.from_numpyro(m13_1), "m13.2": az.from_numpyro(m13_2)}
az.compare(model_dict, ic="waic", scale="deviance")
UserWarning: The default method used to estimate the weights for each model,has changed from BB-pseudo-BMA to stacking UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
rank | waic | p_waic | d_waic | weight | se | dse | warning | waic_scale | |
---|---|---|---|---|---|---|---|---|---|
m13.2 | 0 | 200.729635 | 21.306688 | 0.000000 | 1.000000e+00 | 7.149090 | 0.000000 | True | deviance |
m13.1 | 1 | 215.670229 | 26.160694 | 14.940594 | 1.154632e-13 | 4.379542 | 3.857661 | True | deviance |
# extract NumPyro samples
post = m13_2.get_samples()
# compute posterior-mean intercept for each tank (note: the code takes the
# mean, not the median) and transform to probability with the logistic
d["propsurv.est"] = expit(jnp.mean(post["a"], 0))
# display raw proportions surviving in each tank
plt.plot(jnp.arange(1, 49), d.propsurv, "o", alpha=0.5, zorder=3)
plt.gca().set(ylim=(-0.05, 1.05), xlabel="tank", ylabel="proportion survival")
plt.gca().set(xticks=[1, 16, 32, 48], xticklabels=[1, 16, 32, 48])
# overlay posterior means (open circles; shrunk toward the population mean)
plt.plot(jnp.arange(1, 49), d["propsurv.est"], "ko", mfc="w")
# mark posterior mean probability across tanks
plt.gca().axhline(y=jnp.mean(expit(post["a_bar"])), c="k", ls="--", lw=1)
# draw vertical dividers between tank densities (16 tanks per density group)
plt.gca().axvline(x=16.5, c="k", lw=0.5)
plt.gca().axvline(x=32.5, c="k", lw=0.5)
plt.annotate("small tanks", (8, 0), ha="center")
plt.annotate("medium tanks", (16 + 8, 0), ha="center")
plt.annotate("large tanks", (32 + 8, 0), ha="center")
plt.show()
# show first 100 populations in the posterior
plt.subplot(xlim=(-3, 4), ylim=(0, 0.35), xlabel="log-odds survive", ylabel="Density")
# the evaluation grid is loop-invariant: build it once instead of 100 times
x = jnp.linspace(-3, 4, 101)
for i in range(100):
    # one Normal(a_bar, sigma) population density per posterior draw
    plt.plot(
        x,
        jnp.exp(dist.Normal(post["a_bar"][i], post["sigma"][i]).log_prob(x)),
        "k",
        alpha=0.2,
    )
plt.show()
# sample 8000 imaginary tanks from the posterior distribution
# NOTE: post holds 4 chains x 500 draws = 2000 samples and randint's maxval is
# exclusive, so maxval must be 2000 — the previous maxval=1999 could never
# select the final posterior draw (indices stopped at 1998).
idxs = random.randint(random.PRNGKey(1), (8000,), minval=0, maxval=2000)
sim_tanks = dist.Normal(post["a_bar"][idxs], post["sigma"][idxs]).sample(
    random.PRNGKey(2)
)
# transform to probability and visualize
az.plot_kde(expit(sim_tanks), bw=0.3)
plt.show()
# Simulate ponds with known population parameters to test partial pooling.
a_bar = 1.5  # true population mean log-odds of survival
sigma = 1.5  # true population standard deviation
nponds = 60
# 15 ponds each of initial density 5, 10, 25 and 35 tadpoles
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)
# true per-pond intercepts drawn from the population
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5005), (nponds,))
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))
# aside: python's range and jax's arange are distinct types; both work here
print(type(range(3)))
print(type(jnp.arange(3)))
<class 'range'> <class 'jaxlib.xla_extension.DeviceArray'>
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
random.PRNGKey(0)
)
dsim["p_nopool"] = dsim.Si / dsim.Ni
dat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)
def model(pond, Ni, Si):
    """Partial-pooling binomial survival model over simulated ponds."""
    # population (hyper) parameters
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # one intercept per pond, shrunk toward a_bar
    pond_effects = numpyro.sample(
        "a_pond", dist.Normal(a_bar, sigma), sample_shape=pond.shape
    )
    numpyro.sample("Si", dist.Binomial(Ni, logits=pond_effects[pond]), obs=Si)


m13_3 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_3.run(random.PRNGKey(0), **dat)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# posterior summary with 89% compatibility intervals
m13_3.print_summary(0.89)
mean std median 5.5% 94.5% n_eff r_hat a_bar 1.55 0.28 1.54 1.09 1.97 2640.58 1.00 a_pond[0] 3.01 1.36 2.89 0.89 5.11 2146.07 1.00 a_pond[1] 3.03 1.39 2.91 0.76 5.06 2518.07 1.00 a_pond[2] 0.74 0.89 0.71 -0.65 2.18 3881.46 1.00 a_pond[3] 3.01 1.31 2.90 1.14 5.21 2339.89 1.00 a_pond[4] 2.99 1.35 2.85 0.83 5.02 2502.75 1.00 a_pond[5] 2.97 1.35 2.82 0.85 5.04 2394.69 1.00 a_pond[6] -0.04 0.83 -0.03 -1.27 1.35 2422.02 1.00 a_pond[7] 3.01 1.36 2.89 0.88 5.06 1866.54 1.00 a_pond[8] 1.66 1.07 1.59 -0.10 3.25 3248.79 1.00 a_pond[9] 1.62 1.03 1.52 -0.08 3.12 2579.92 1.00 a_pond[10] -1.72 1.09 -1.64 -3.41 -0.03 3104.02 1.00 a_pond[11] -1.74 1.13 -1.63 -3.57 -0.09 2501.22 1.00 a_pond[12] 3.00 1.35 2.87 0.68 4.90 2192.21 1.00 a_pond[13] 3.05 1.39 2.88 0.85 5.25 2250.77 1.00 a_pond[14] -0.04 0.84 -0.03 -1.40 1.29 3785.95 1.00 a_pond[15] 1.51 0.75 1.46 0.33 2.67 3063.76 1.00 a_pond[16] 3.40 1.25 3.25 1.43 5.30 2503.37 1.00 a_pond[17] 2.24 0.94 2.16 0.72 3.66 3027.82 1.00 a_pond[18] 3.44 1.31 3.30 1.27 5.35 2444.29 1.00 a_pond[19] 0.59 0.69 0.57 -0.55 1.65 3086.36 1.00 a_pond[20] 1.55 0.80 1.50 0.26 2.74 2602.71 1.00 a_pond[21] 2.28 0.94 2.23 0.80 3.73 2771.97 1.00 a_pond[22] 0.58 0.64 0.57 -0.48 1.57 2846.15 1.00 a_pond[23] -1.07 0.70 -1.03 -2.23 -0.01 2388.72 1.00 a_pond[24] 3.43 1.29 3.29 1.38 5.28 1926.15 1.00 a_pond[25] 1.54 0.76 1.49 0.39 2.79 3489.25 1.00 a_pond[26] 3.43 1.30 3.26 1.43 5.53 1902.42 1.00 a_pond[27] 1.06 0.71 1.03 -0.10 2.12 2640.76 1.00 a_pond[28] 2.25 0.94 2.17 0.82 3.71 3871.51 1.00 a_pond[29] -0.21 0.63 -0.20 -1.24 0.77 3303.80 1.00 a_pond[30] -3.19 0.99 -3.06 -4.73 -1.70 1657.59 1.00 a_pond[31] -0.17 0.40 -0.16 -0.77 0.50 3286.27 1.00 a_pond[32] 0.82 0.43 0.81 0.10 1.47 3592.00 1.00 a_pond[33] 0.64 0.41 0.64 -0.05 1.25 2996.21 1.00 a_pond[34] -1.75 0.56 -1.70 -2.61 -0.84 3351.40 1.00 a_pond[35] 1.72 0.54 1.70 0.85 2.61 3152.36 1.00 a_pond[36] 0.82 0.43 0.81 0.22 1.60 2752.63 1.00 a_pond[37] 4.02 1.20 3.87 2.22 5.90 1909.16 1.00 a_pond[38] 4.01 1.16 3.88 
2.18 5.62 2083.64 1.00 a_pond[39] 3.07 0.87 2.98 1.73 4.44 2040.58 1.00 a_pond[40] 3.06 0.88 2.98 1.65 4.35 1979.37 1.00 a_pond[41] 1.73 0.56 1.70 0.89 2.65 3205.70 1.00 a_pond[42] 3.06 0.87 2.97 1.70 4.40 2023.91 1.00 a_pond[43] 2.47 0.67 2.41 1.35 3.46 2418.37 1.00 a_pond[44] 3.07 0.85 2.97 1.74 4.36 2668.19 1.00 a_pond[45] -1.29 0.41 -1.26 -1.88 -0.62 3217.64 1.00 a_pond[46] 1.44 0.42 1.42 0.74 2.08 2950.12 1.00 a_pond[47] 0.23 0.33 0.22 -0.29 0.74 3520.28 1.00 a_pond[48] 0.58 0.36 0.57 -0.05 1.10 3651.30 1.00 a_pond[49] 3.33 0.81 3.26 2.05 4.60 2688.02 1.00 a_pond[50] 1.27 0.39 1.26 0.64 1.90 3220.04 1.00 a_pond[51] 3.35 0.83 3.30 2.10 4.65 2192.53 1.00 a_pond[52] 2.39 0.59 2.37 1.40 3.22 2908.23 1.00 a_pond[53] 0.71 0.35 0.70 0.14 1.28 3976.04 1.00 a_pond[54] 0.34 0.34 0.35 -0.18 0.91 3431.03 1.00 a_pond[55] 2.80 0.70 2.76 1.71 3.82 2395.95 1.00 a_pond[56] 2.10 0.54 2.06 1.28 2.97 2852.21 1.00 a_pond[57] 0.58 0.35 0.58 -0.00 1.13 3492.97 1.00 a_pond[58] 0.46 0.35 0.47 -0.02 1.07 3989.22 1.00 a_pond[59] 3.34 0.81 3.26 2.09 4.57 2403.33 1.00 sigma 1.87 0.26 1.86 1.46 2.25 877.63 1.00 Number of divergences: 0
post = m13_3.get_samples()
# partial-pooling estimate: posterior mean survival probability per pond
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)
# true survival probability used to simulate the data
dsim["p_true"] = expit(dsim.true_a.values)
# absolute estimation error for each estimator
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
dsim["nopool_error"] = nopool_error
dsim["partpool_error"] = partpool_error
# average error by pond size: pooling helps most in the small ponds
nopool_avg = dsim.groupby("Ni")["nopool_error"].mean()
partpool_avg = dsim.groupby("Ni")["partpool_error"].mean()
# Repeat the whole pond simulation with a fresh seed, reusing the compiled
# model from m13_3 instead of redefining it.
a_bar = 1.5
sigma = 1.5
nponds = 60
Ni = jnp.repeat(jnp.array([5, 10, 25, 35]), repeats=15)
# new true intercepts (different PRNG key than before)
a_pond = dist.Normal(a_bar, sigma).sample(random.PRNGKey(5006), (nponds,))
dsim = pd.DataFrame(dict(pond=range(1, nponds + 1), Ni=Ni, true_a=a_pond))
dsim["Si"] = dist.Binomial(dsim.Ni.values, logits=dsim.true_a.values).sample(
    random.PRNGKey(0)
)
dsim["p_nopool"] = dsim.Si / dsim.Ni
newdat = dict(Si=dsim.Si.values, Ni=dsim.Ni.values, pond=dsim.pond.values - 1)
# reuse the model function attached to the previous sampler
m13_3new = MCMC(
    NUTS(m13_3.sampler.model), num_warmup=1000, num_samples=1000, num_chains=4
)
m13_3new.run(random.PRNGKey(0), **newdat)
post = m13_3new.get_samples()
dsim["p_partpool"] = jnp.mean(expit(post["a_pond"]), 0)
dsim["p_true"] = expit(dsim.true_a.values)
nopool_error = (dsim.p_nopool - dsim.p_true).abs()
partpool_error = (dsim.p_partpool - dsim.p_true).abs()
plt.scatter(range(1, 61), nopool_error, label="nopool", alpha=0.8)
plt.gca().set(xlabel="pond", ylabel="absolute error")
plt.scatter(
    range(1, 61),
    partpool_error,
    label="partpool",
    s=50,
    edgecolor="black",
    facecolor="none",
)
plt.legend()
plt.show()
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
chimpanzees = pd.read_csv("../data/chimpanzees.csv", sep=";")
d = chimpanzees
# treatment index 1..4 combining prosocial-option side and partner presence
d["treatment"] = 1 + d.prosoc_left + 2 * d.condition
# 0-indexed cluster variables for NumPyro
dat_list = dict(
    pulled_left=d.pulled_left.values,
    actor=d.actor.values - 1,
    block_id=d.block.values - 1,
    treatment=d.treatment.values - 1,
)
def model(actor, block_id, treatment, pulled_left=None, link=False):
    """m13.4: Bernoulli GLMM with varying actor and block intercepts.

    When ``link=True`` the probability is also recorded at the
    deterministic site "p" (used later with Predictive).
    """
    # hyper-priors
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    # adaptive priors: 7 actors and 6 blocks; fixed prior for 4 treatments
    actor_eff = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    block_eff = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    log_odds = actor_eff[actor] + block_eff[block_id] + treat_eff[treatment]
    if link:
        numpyro.deterministic("p", expit(log_odds))
    numpyro.sample("pulled_left", dist.Binomial(logits=log_odds), obs=pulled_left)


m13_4 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4.run(random.PRNGKey(0), **dat_list)
print("Number of divergences:", m13_4.get_extra_fields()["diverging"].sum())
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
Number of divergences: 7
# tabular summary plus a forest plot of the same posterior
m13_4.print_summary()
post = m13_4.get_samples(group_by_chain=True)
az.plot_forest(post, combined=True, hdi_prob=0.89)  # also plot
plt.show()
mean std median 5.0% 95.0% n_eff r_hat a[0] -0.38 0.37 -0.38 -0.98 0.22 597.32 1.00 a[1] 4.67 1.29 4.49 2.78 6.67 948.84 1.00 a[2] -0.69 0.37 -0.70 -1.30 -0.07 566.81 1.00 a[3] -0.70 0.39 -0.69 -1.36 -0.09 512.00 1.00 a[4] -0.39 0.37 -0.39 -1.01 0.19 546.34 1.00 a[5] 0.56 0.38 0.56 -0.08 1.15 607.70 1.00 a[6] 2.09 0.47 2.10 1.26 2.81 757.08 1.00 a_bar 0.57 0.72 0.57 -0.61 1.74 1482.01 1.00 b[0] -0.11 0.30 -0.11 -0.59 0.37 614.87 1.00 b[1] 0.41 0.30 0.41 -0.07 0.93 623.30 1.00 b[2] -0.46 0.31 -0.45 -0.99 -0.01 574.28 1.00 b[3] 0.30 0.30 0.31 -0.18 0.81 620.34 1.00 g[0] -0.18 0.23 -0.13 -0.56 0.10 556.08 1.01 g[1] 0.05 0.20 0.03 -0.24 0.41 613.18 1.01 g[2] 0.06 0.19 0.03 -0.26 0.38 596.79 1.00 g[3] 0.02 0.19 0.01 -0.28 0.34 947.51 1.00 g[4] -0.03 0.19 -0.01 -0.35 0.27 933.78 1.00 g[5] 0.13 0.22 0.09 -0.21 0.47 485.27 1.00 sigma_a 2.03 0.65 1.93 1.06 2.96 1194.46 1.00 sigma_g 0.23 0.18 0.19 0.00 0.47 309.51 1.01 Number of divergences: 7
def model(actor, treatment, pulled_left):
    """m13.5: varying actor intercepts only (block cluster omitted)."""
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    actor_eff = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    numpyro.sample(
        "pulled_left",
        dist.Binomial(logits=actor_eff[actor] + treat_eff[treatment]),
        obs=pulled_left,
    )


m13_5 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
# pass model arguments by keyword so the mapping to parameters is explicit
m13_5.run(
    random.PRNGKey(14),
    actor=dat_list["actor"],
    treatment=dat_list["treatment"],
    pulled_left=dat_list["pulled_left"],
)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# WAIC comparison: dropping block barely changes predictive performance
az.compare(
    {"m13.4": az.from_numpyro(m13_4), "m13.5": az.from_numpyro(m13_5)},
    ic="waic",
    scale="deviance",
)
UserWarning: The default method used to estimate the weights for each model,has changed from BB-pseudo-BMA to stacking
rank | waic | p_waic | d_waic | weight | se | dse | warning | waic_scale | |
---|---|---|---|---|---|---|---|---|---|
m13.5 | 0 | 531.103967 | 8.554003 | 0.000000 | 1.0 | 19.200204 | 0.000000 | False | deviance |
m13.4 | 1 | 532.297533 | 10.828377 | 1.193565 | 0.0 | 19.411863 | 1.873273 | False | deviance |
def model(actor, block_id, treatment, pulled_left):
    """m13.6: like m13.4 but treatment effects also get a learned scale sigma_b."""
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    sigma_b = numpyro.sample("sigma_b", dist.Exponential(1))
    actor_eff = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(7,))
    block_eff = numpyro.sample("g", dist.Normal(0, sigma_g), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, sigma_b), sample_shape=(4,))
    log_odds = actor_eff[actor] + block_eff[block_id] + treat_eff[treatment]
    numpyro.sample("pulled_left", dist.Binomial(logits=log_odds), obs=pulled_left)


m13_6 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_6.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_6.get_extra_fields()["diverging"].sum())
# posterior mean treatment effects: nearly identical between m13.4 and m13.6
{
    "m13.4": jnp.mean(m13_4.get_samples()["b"], 0),
    "m13.6": jnp.mean(m13_6.get_samples()["b"], 0),
}
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
Number of divergences: 10
{'m13.4': DeviceArray([-0.11273014, 0.41450262, -0.45547444, 0.30242366], dtype=float32), 'm13.6': DeviceArray([-0.11401905, 0.3812172 , -0.44742095, 0.27241236], dtype=float32)}
def model():
    """Centered "devil's funnel": x ~ Normal(0, exp(v)) defeats NUTS."""
    log_scale = numpyro.sample("v", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(0, jnp.exp(log_scale)))


m13_7 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7.run(random.PRNGKey(0))
m13_7.print_summary()
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.0% 95.0% n_eff r_hat v 1.24 2.58 1.22 -2.75 4.49 4.12 1.59 x 43.93 349.99 0.01 -43.64 51.78 49.60 1.06 Number of divergences: 201
def model():
    """Non-centered funnel: sample a standard normal and rescale deterministically."""
    log_scale = numpyro.sample("v", dist.Normal(0, 3))
    raw = numpyro.sample("z", dist.Normal(0, 1))
    # x is now a derived quantity, not a sampled site
    numpyro.deterministic("x", raw * jnp.exp(log_scale))


m13_7nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_7nc.run(random.PRNGKey(0))
m13_7nc.print_summary(exclude_deterministic=False)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.0% 95.0% n_eff r_hat v -0.04 3.00 -0.13 -4.85 5.01 2019.83 1.00 x 2.92 149.31 -0.00 -26.96 30.74 1719.69 1.00 z -0.02 0.98 -0.01 -1.66 1.57 1963.93 1.00 Number of divergences: 0
# re-run m13.4 with a stricter target acceptance rate to fight divergences
m13_4b = MCMC(
    NUTS(m13_4.sampler.model, target_accept_prob=0.99),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m13_4b.run(random.PRNGKey(13), **dat_list)
# divergences remain: reparameterization (below) is the real fix
jnp.sum(m13_4b.get_extra_fields()["diverging"])
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
DeviceArray(11, dtype=int32)
def model(actor, block_id, treatment, pulled_left):
    """Non-centered m13.4: standardized effects rescaled inside the linear model."""
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1))
    sigma_g = numpyro.sample("sigma_g", dist.Exponential(1))
    # standardized (unit-normal) varying effects for actors and blocks
    z_actor = numpyro.sample("z", dist.Normal(0, 1), sample_shape=(7,))
    z_block = numpyro.sample("x", dist.Normal(0, 1), sample_shape=(6,))
    treat_eff = numpyro.sample("b", dist.Normal(0, 0.5), sample_shape=(4,))
    log_odds = (
        a_bar
        + z_actor[actor] * sigma_a
        + z_block[block_id] * sigma_g
        + treat_eff[treatment]
    )
    numpyro.sample("pulled_left", dist.Binomial(logits=log_odds), obs=pulled_left)


m13_4nc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_4nc.run(random.PRNGKey(16), **dat_list)
print("Number of divergences:", m13_4nc.get_extra_fields()["diverging"].sum())
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
Number of divergences: 0
# effective sample size per parameter, centered (m13.4) vs non-centered (m13.4nc)
neff_c = {
    k: effective_sample_size(v)
    for k, v in m13_4.get_samples(group_by_chain=True).items()
}
neff_nc = {
    k: effective_sample_size(v)
    for k, v in m13_4nc.get_samples(group_by_chain=True).items()
}
par_names = []
# key lists ordered so matching rows line up: z/x are the non-centered a/g
keys_c = ["b", "a", "g", "a_bar", "sigma_a", "sigma_g"]
keys_nc = ["b", "z", "x", "a_bar", "sigma_a", "sigma_g"]
for k in keys_c:
    if jnp.ndim(neff_c[k]) == 0:
        # scalar site: single row label
        par_names += [k]
    else:
        # vector site: one label per element, e.g. "a[0]"
        par_names += [k + "[{}]".format(i) for i in range(neff_c[k].size)]
# NOTE: the next two lines rebind neff_c/neff_nc from dicts to flat arrays,
# so they must run after par_names is built from the dict form above
neff_c = jnp.concatenate([neff_c[k].reshape(-1) for k in keys_c])
neff_nc = jnp.concatenate([neff_nc[k].reshape(-1) for k in keys_nc])
neff_table = pd.DataFrame(dict(neff_c=neff_c, neff_nc=neff_nc))
neff_table.index = par_names
neff_table.round()
neff_c | neff_nc | |
---|---|---|
b[0] | 615.0 | 1095.0 |
b[1] | 623.0 | 1072.0 |
b[2] | 574.0 | 904.0 |
b[3] | 620.0 | 987.0 |
a[0] | 597.0 | 382.0 |
a[1] | 949.0 | 1045.0 |
a[2] | 567.0 | 379.0 |
a[3] | 512.0 | 383.0 |
a[4] | 546.0 | 369.0 |
a[5] | 608.0 | 367.0 |
a[6] | 757.0 | 640.0 |
g[0] | 556.0 | 1675.0 |
g[1] | 613.0 | 2287.0 |
g[2] | 597.0 | 1947.0 |
g[3] | 948.0 | 2062.0 |
g[4] | 934.0 | 2232.0 |
g[5] | 485.0 | 1443.0 |
a_bar | 1482.0 | 429.0 |
sigma_a | 1194.0 | 647.0 |
sigma_g | 310.0 | 849.0 |
# posterior predictions for one actor across all four treatments
chimp = 2
d_pred = dict(
    actor=jnp.repeat(chimp, 4) - 1,  # shift to the 0-indexed actor id
    treatment=jnp.arange(4),
    block_id=jnp.repeat(1, 4) - 1,  # block 1, 0-indexed
)
# link=True makes the model record the deterministic probability site "p"
p = Predictive(m13_4.sampler.model, m13_4.get_samples())(
    random.PRNGKey(0), link=True, **d_pred
)["p"]
p_mu = jnp.mean(p, 0)
p_ci = jnp.percentile(p, q=jnp.array([5.5, 94.5]), axis=0)
post = m13_4.get_samples()
# peek at the first few values of each posterior site
{k: v.reshape(-1)[:5] for k, v in post.items()}
{'a': DeviceArray([ 0.17939359, 3.2472124 , -0.42817664, -0.3104607 , 0.09759563], dtype=float32), 'a_bar': DeviceArray([ 0.57802844, 0.5723409 , 0.24400227, -0.14573385, 0.60223407], dtype=float32), 'b': DeviceArray([-0.35315567, 0.07911116, -0.7404461 , -0.29147664, -0.1611932 ], dtype=float32), 'g': DeviceArray([-0.06372303, -0.07348639, -0.07706347, -0.00371453, 0.11766291], dtype=float32), 'sigma_a': DeviceArray([2.015479 , 2.4767263, 3.077415 , 1.6289241, 1.7870352], dtype=float32), 'sigma_g': DeviceArray([0.05992268, 0.21187922, 0.09104711, 0.2905614 , 0.26771992], dtype=float32)}
az.plot_kde(post["a"][:, 4])
plt.show()
def p_link(treatment, actor=0, block_id=0):
    """Posterior pull-left probability for one treatment/actor/block combo.

    Reads the global ``post`` samples; returns one probability per draw.
    """
    a, g, b = post["a"], post["g"], post["b"]
    logodds = a[:, actor] + g[:, block_id] + b[:, treatment]
    return expit(logodds)


# lax.map stacks treatments along axis 0, so p_raw has shape (4, n_samples).
# Summaries must therefore reduce over the samples axis (axis 1) to give one
# value per treatment; axis 0 (the previous code) averaged over treatments
# instead. The parallel p_link_abar cell below already uses axis 1.
p_raw = lax.map(lambda i: p_link(i, actor=1, block_id=0), jnp.arange(4))
p_mu = jnp.mean(p_raw, 1)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), 1)
def p_link_abar(treatment):
    # pull-left probability for the *average* actor (a_bar), ignoring block
    logodds = post["a_bar"] + post["b"][:, treatment]
    return expit(logodds)


# p_raw has shape (4, n_samples); summarize over the samples axis (axis 1)
p_raw = lax.map(p_link_abar, jnp.arange(4))
p_mu = jnp.mean(p_raw, 1)
p_ci = jnp.percentile(p_raw, jnp.array([5.5, 94.5]), 1)
plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
plt.plot(range(1, 5), p_mu)
plt.fill_between(range(1, 5), p_ci[0], p_ci[1], color="k", alpha=0.2)
plt.show()
# simulate brand-new actors from the posterior population of intercepts
a_sim = dist.Normal(post["a_bar"], post["sigma_a"]).sample(random.PRNGKey(0))


def p_link_asim(treatment):
    # pull-left probability for each simulated actor under one treatment
    logodds = a_sim + post["b"][:, treatment]
    return expit(logodds)


p_raw_asim = lax.map(p_link_asim, jnp.arange(4))
plt.subplot(
    xlabel="treatment", ylabel="proportion pulled left", ylim=(0, 1), xlim=(0.9, 4.1)
)
plt.gca().set(xticks=range(1, 5), xticklabels=["R/N", "L/N", "R/P", "L/P"])
# one line per simulated actor (first 100), across the four treatments
for i in range(100):
    plt.plot(range(1, 5), p_raw_asim[:, i], color="k", alpha=0.25)
# Bangladesh fertility survey; note the raw district ids are not contiguous
# (district 54 is absent in the sorted output below)
bangladesh = pd.read_csv("../data/bangladesh.csv", sep=";")
d = bangladesh
jnp.sort(d.district.unique())
DeviceArray([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61], dtype=int32)
d["district_id"] = d.district.astype("category").cat.codes
jnp.sort(d.district_id.unique())
DeviceArray([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59], dtype=int8)