!pip install -q numpyro arviz daft networkx
import collections
import itertools
import math
import os
import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
# Render inline figures as SVG only when the SVG environment variable is set.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# Global plotting style for ArviZ/matplotlib figures.
az.style.use("arviz-darkgrid")
# Run all JAX/NumPyro computation on the CPU backend.
numpyro.set_platform("cpu")
# Load the Waffle House / divorce data. `d` is a second name for the same
# DataFrame (not a copy), so columns added via `d` also appear on WaffleDivorce.
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
# z-score median age at marriage (A) and divorce rate (D)
d["A"] = (d.MedianAgeMarriage - d.MedianAgeMarriage.mean()) / d.MedianAgeMarriage.std()
d["D"] = (d.Divorce - d.Divorce.mean()) / d.Divorce.std()
# raw-scale standard deviation of median marriage age, shown as the cell result
d.MedianAgeMarriage.std()
1.2436303013880823
def model(A, D=None):
    """m5.1: regress standardized divorce rate D on standardized median age A.

    A : standardized median age at marriage for each state.
    D : observed standardized divorce rate, or None to simulate it.
    """
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    slope = numpyro.sample("bA", dist.Normal(loc=0, scale=0.5))
    noise_sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    linear_pred = numpyro.deterministic("mu", intercept + slope * A)
    numpyro.sample("D", dist.Normal(linear_pred, noise_sd), obs=D)


# Laplace (quadratic) approximation, fitted by optimizing the ELBO with Adam.
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_1,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    A=d.A.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_1 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1196.77it/s, init loss: 2138.6682, avg. loss [951-1000]: 59.4385]
# Prior predictive check: with neither guide nor posterior samples, Predictive
# draws from the prior. mu is evaluated at the two ends of the A range.
predictive = Predictive(m5_1.model, num_samples=1000, return_sites=["mu"])
prior_pred = predictive(random.PRNGKey(10), A=jnp.array([-2, 2]))
mu = prior_pred["mu"]
plt.subplot(xlim=(-2, 2), ylim=(-2, 2))
# one straight line per prior draw, first 20 draws only
for line in mu[:20]:
    plt.plot([-2, 2], line, "k", alpha=0.4)
# compute percentile interval of mean over a grid of A values
A_seq = jnp.linspace(start=-3, stop=3.2, num=30)
post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))
# drop the stored mu (evaluated at the observed A) so Predictive recomputes it on A_seq
post.pop("mu")
post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# plot it all: plot_pair puts the first variable on the x-axis, so A must come
# first to line up with the regression line below (A on x, D on y)
az.plot_pair(d[["A", "D"]].to_dict(orient="list"))
plt.plot(A_seq, mu_mean, "k")
plt.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
# Standardized marriage rate.
d["M"] = (d.Marriage - d.Marriage.mean()) / d.Marriage.std()


def model(M, D=None):
    """m5.2: regress standardized divorce rate D on standardized marriage rate M.

    M : standardized marriage rate per state.
    D : observed standardized divorce rate, or None to simulate it.
    """
    alpha = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    beta_m = numpyro.sample("bM", dist.Normal(loc=0, scale=0.5))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    numpyro.sample("D", dist.Normal(alpha + beta_m * M, sd), obs=D)


m5_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_2,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_2 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1218.38it/s, init loss: 962.7464, avg. loss [951-1000]: 66.1313]
# Draw the DAG A -> D, A -> M, M -> D with daft at fixed node positions.
dag5_1 = nx.DiGraph()
dag5_1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
pgm = daft.PGM()
coordinates = {"A": (0, 0), "D": (1, 1), "M": (2, 0)}
for node_name in dag5_1.nodes:
    pgm.add_node(node_name, node_name, *coordinates[node_name])
for tail, head in dag5_1.edges:
    pgm.add_edge(tail, head)
# daft manages its own layout, so constrained layout is switched off while rendering
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
plt.gca().invert_yaxis()
plt.show()
# Implied conditional independencies of the DAG A -> D, A -> M (no M -> D edge).
DMA_dag2 = nx.DiGraph()
DMA_dag2.add_edges_from([("A", "D"), ("A", "M")])
# `nx.d_separated` is deprecated (NetworkX >= 3.3) in favour of
# `nx.is_d_separator`; use whichever this NetworkX version provides.
d_sep = getattr(nx, "is_d_separator", None) or nx.d_separated
conditional_independencies = collections.defaultdict(list)
for pair in itertools.combinations(sorted(DMA_dag2.nodes), 2):
    remaining = sorted(set(DMA_dag2.nodes) - set(pair))
    # try conditioning sets from smallest to largest so only minimal sets are reported
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[pair]):
                continue  # a smaller separating set is already recorded
            if d_sep(DMA_dag2, {pair[0]}, {pair[1]}, set(subset)):
                conditional_independencies[pair].append(set(subset))
                print(f"{pair[0]} _||_ {pair[1]}" + (f" | {' '.join(subset)}" if subset else ""))
D _||_ M | A
# Implied conditional independencies of the DAG A -> D, A -> M, M -> D.
# Every pair is adjacent, so nothing is expected to be printed.
DMA_dag1 = nx.DiGraph()
DMA_dag1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
# `nx.d_separated` is deprecated (NetworkX >= 3.3) in favour of
# `nx.is_d_separator`; use whichever this NetworkX version provides.
d_sep = getattr(nx, "is_d_separator", None) or nx.d_separated
conditional_independencies = collections.defaultdict(list)
for pair in itertools.combinations(sorted(DMA_dag1.nodes), 2):
    remaining = sorted(set(DMA_dag1.nodes) - set(pair))
    # try conditioning sets from smallest to largest so only minimal sets are reported
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[pair]):
                continue  # a smaller separating set is already recorded
            if d_sep(DMA_dag1, {pair[0]}, {pair[1]}, set(subset)):
                conditional_independencies[pair].append(set(subset))
                print(f"{pair[0]} _||_ {pair[1]}" + (f" | {' '.join(subset)}" if subset else ""))
def model(M, A, D=None):
    """m5.3: divorce rate D regressed on both marriage rate M and median age A."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    slope_m = numpyro.sample("bM", dist.Normal(loc=0, scale=0.5))
    slope_a = numpyro.sample("bA", dist.Normal(loc=0, scale=0.5))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    mean = numpyro.deterministic("mu", intercept + slope_m * M + slope_a * A)
    numpyro.sample("D", dist.Normal(mean, sd), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    M=d.M.values,
    A=d.A.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_3 = svi_result.params
# 1000 draws from the Laplace approximation, summarized with 89% intervals
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
100%|██████████| 1000/1000 [00:00<00:00, 1002.82it/s, init loss: 3201.7393, avg. loss [951-1000]: 59.5721]
mean std median 5.5% 94.5% n_eff r_hat a -0.00 0.10 -0.01 -0.16 0.14 1049.96 1.00 bA -0.61 0.16 -0.61 -0.86 -0.36 822.38 1.00 bM -0.06 0.16 -0.06 -0.31 0.19 984.99 1.00 mu[0] 0.36 0.13 0.37 0.15 0.57 921.63 1.00 mu[1] 0.32 0.21 0.32 -0.01 0.66 900.77 1.00 mu[2] 0.12 0.10 0.12 -0.03 0.28 995.35 1.00 mu[3] 0.76 0.21 0.75 0.43 1.10 911.65 1.00 mu[4] -0.35 0.12 -0.35 -0.52 -0.14 1070.56 1.00 mu[5] 0.12 0.15 0.12 -0.13 0.35 861.01 1.00 mu[6] -0.71 0.17 -0.70 -0.95 -0.42 1053.26 1.00 mu[7] -0.31 0.20 -0.32 -0.63 0.03 865.14 1.00 mu[8] -1.74 0.40 -1.74 -2.28 -1.00 804.96 1.00 mu[9] -0.12 0.14 -0.12 -0.37 0.08 1072.06 1.00 mu[10] 0.04 0.12 0.04 -0.14 0.23 858.83 1.00 mu[11] -0.49 0.30 -0.49 -0.91 0.05 872.77 1.00 mu[12] 1.30 0.27 1.30 0.87 1.72 867.10 1.00 mu[13] -0.43 0.13 -0.43 -0.61 -0.21 1125.00 1.00 mu[14] 0.18 0.11 0.18 0.00 0.35 987.73 1.00 mu[15] 0.30 0.11 0.30 0.11 0.47 909.55 1.00 mu[16] 0.48 0.13 0.49 0.27 0.70 875.38 1.00 mu[17] 0.58 0.15 0.58 0.35 0.82 867.29 1.00 mu[18] 0.07 0.10 0.06 -0.09 0.21 1002.97 1.00 mu[19] -0.07 0.27 -0.07 -0.50 0.34 1018.90 1.00 mu[20] -0.58 0.15 -0.58 -0.80 -0.32 1013.12 1.00 mu[21] -1.13 0.24 -1.13 -1.47 -0.74 928.11 1.00 mu[22] -0.11 0.16 -0.11 -0.38 0.12 1058.69 1.00 mu[23] -0.05 0.21 -0.05 -0.38 0.27 1029.22 1.00 mu[24] 0.13 0.11 0.13 -0.05 0.30 1018.21 1.00 mu[25] 0.24 0.15 0.24 0.03 0.50 987.36 1.00 mu[26] 0.20 0.14 0.19 -0.04 0.42 1004.82 1.00 mu[27] 0.33 0.14 0.32 0.10 0.53 944.68 1.00 mu[28] -0.31 0.14 -0.31 -0.53 -0.09 1106.73 1.00 mu[29] -0.72 0.19 -0.72 -1.03 -0.42 1099.72 1.00 mu[30] 0.12 0.10 0.12 -0.03 0.28 992.25 1.00 mu[31] -1.09 0.24 -1.10 -1.44 -0.70 891.79 1.00 mu[32] 0.17 0.10 0.17 0.00 0.33 973.21 1.00 mu[33] 0.26 0.24 0.26 -0.13 0.63 903.59 1.00 mu[34] -0.07 0.15 -0.07 -0.34 0.14 1057.21 1.00 mu[35] 0.75 0.18 0.75 0.49 1.05 865.16 1.00 mu[36] 0.04 0.11 0.04 -0.15 0.20 1071.57 1.00 mu[37] -0.44 0.17 -0.44 -0.70 -0.17 1096.13 1.00 mu[38] -0.97 0.21 -0.97 -1.31 -0.63 1057.69 1.00 mu[39] -0.14 0.11 -0.14 -0.32 
0.04 1109.30 1.00 mu[40] 0.22 0.11 0.22 0.04 0.39 962.55 1.00 mu[41] 0.43 0.16 0.42 0.18 0.69 929.76 1.00 mu[42] 0.39 0.12 0.40 0.18 0.57 890.80 1.00 mu[43] 1.19 0.30 1.20 0.72 1.70 930.35 1.00 mu[44] -0.36 0.15 -0.35 -0.59 -0.12 1105.88 1.00 mu[45] -0.18 0.11 -0.18 -0.34 0.01 937.86 1.00 mu[46] 0.05 0.10 0.05 -0.12 0.21 891.13 1.00 mu[47] 0.48 0.13 0.48 0.27 0.70 875.38 1.00 mu[48] -0.08 0.14 -0.08 -0.32 0.13 1065.16 1.00 mu[49] 0.74 0.34 0.73 0.16 1.22 951.13 1.00 sigma 0.80 0.08 0.79 0.68 0.92 971.25 1.00
# Slope posteriors of m5.1, m5.2 and m5.3 side by side. sample_shape=(1, 1000)
# gives each model one "chain" of 1000 draws, the (chain, draw) layout ArviZ uses.
fits = [
    ("m5.1", m5_1, p5_1, 1),
    ("m5.2", m5_2, p5_2, 2),
    ("m5.3", m5_3, p5_3, 3),
]
coeftab = {
    label: guide.sample_posterior(random.PRNGKey(seed), params, sample_shape=(1, 1000))
    for label, guide, params, seed in fits
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bA", "bM"],
    hdi_prob=0.89,
)
plt.show()
# Simulated states where age drives both marriage rate and divorce
# (A -> M, A -> D) and M has no direct effect on D.
N = 50  # number of simulated States
age = dist.Normal().sample(random.PRNGKey(0), (N,))  # A ~ Normal(0, 1)
mar = dist.Normal(loc=age).sample(random.PRNGKey(1))  # M ~ Normal(A, 1)
div = dist.Normal(loc=age).sample(random.PRNGKey(2))  # D ~ Normal(A, 1)
def model(A, M=None):
    """m5.4: regress standardized marriage rate M on standardized median age A.

    The slope site is named "bAM" (the A -> M slope, as in m5_3_A). Elsewhere in
    the notebook "bA" denotes the A -> D slope.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bAM * A)
    numpyro.sample("M", dist.Normal(mu, sigma), obs=M)


m5_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_4, optim.Adam(0.1), Trace_ELBO(), A=d.A.values, M=d.M.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_4 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1218.61it/s, init loss: 2288.6685, avg. loss [951-1000]: 52.6188]
# Residual marriage rate: observed M minus the posterior-mean prediction of M
# at each state's observed A.
post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))
post.pop("mu")  # let Predictive recompute mu from the parameter draws
post_pred = Predictive(m5_4.model, posterior_samples=post)(random.PRNGKey(2), A=d.A.values)
mu = post_pred["mu"]
mu_mean = mu.mean(axis=0)
mu_resid = d.M.values - mu_mean
# Posterior predictive check for m5.3 at each state's own M and A.
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(int(1e4),))
post.pop("mu")  # recompute mu from the parameter draws
# the model takes M and A as required arguments, so the observed predictors are
# passed explicitly; D is left unset and is therefore simulated
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]
# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# simulated observations D come from the same Predictive call
D_sim = post_pred["D"]
D_PI = jnp.percentile(D_sim, q=jnp.array([5.5, 94.5]), axis=0)
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")  # perfect-prediction reference line
# vertical 89% interval of mu for every state
for observed, interval in zip(d.D, mu_PI.T):
    plt.plot([observed] * 2, interval, "b")
fig = plt.gcf()
# label a handful of individual states
for loc, observed, predicted in zip(d.Loc, d.D, mu_mean):
    if loc in ["ID", "UT", "RI", "ME"]:
        ax.annotate(loc, (observed, predicted), xytext=(-25, -5), textcoords="offset pixels")
fig
# Spurious association: y depends only on x_real; x_spur merely tracks x_real.
N = 100  # number of cases
x_real = dist.Normal().sample(random.PRNGKey(0), (N,))  # mean 0, stddev 1
x_spur = dist.Normal(loc=x_real).sample(random.PRNGKey(1))  # centred on x_real
y = dist.Normal(loc=x_real).sample(random.PRNGKey(2))  # centred on x_real
d = pd.DataFrame(dict(y=y, x_real=x_real, x_spur=x_spur))
# Reload the divorce data (d was rebound above) and z-score A, D and M.
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
for new_col, raw_col in [("A", "MedianAgeMarriage"), ("D", "Divorce"), ("M", "Marriage")]:
    d[new_col] = (d[raw_col] - d[raw_col].mean()) / d[raw_col].std()
def model(A, M=None, D=None):
    """m5.3_A: joint model of A -> M and A -> D <- M.

    When M is None, the "M" site is sampled from its regression on A and that
    draw feeds the divorce regression, so Predictive can simulate
    counterfactuals from A alone.
    """
    # A -> M
    m_intercept = numpyro.sample("aM", dist.Normal(loc=0, scale=0.2))
    m_slope = numpyro.sample("bAM", dist.Normal(loc=0, scale=0.5))
    m_sd = numpyro.sample("sigma_M", dist.Exponential(rate=1))
    m_value = numpyro.sample("M", dist.Normal(m_intercept + m_slope * A, m_sd), obs=M)
    # A -> D <- M
    d_intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    d_slope_m = numpyro.sample("bM", dist.Normal(loc=0, scale=0.5))
    d_slope_a = numpyro.sample("bA", dist.Normal(loc=0, scale=0.5))
    d_sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    d_mean = d_intercept + d_slope_m * m_value + d_slope_a * A
    numpyro.sample("D", dist.Normal(d_mean, d_sd), obs=D)


m5_3_A = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3_A,
    optim.Adam(step_size=0.1),
    Trace_ELBO(),
    A=d.A.values,
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_3_A = svi_result.params
100%|██████████| 1000/1000 [00:01<00:00, 734.26it/s, init loss: 10480.9580, avg. loss [951-1000]: 112.1909]
# Total counterfactual effect of A on D: intervene on A, let M follow from A,
# then simulate D from both.
A_seq = jnp.linspace(-2, 2, num=30)
sim_dat = dict(A=A_seq)  # only A is given, so M and D are both simulated
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
s = Predictive(m5_3_A.model, posterior_samples=post)(random.PRNGKey(2), **sim_dat)
D_draws = s["D"]
D_band = jnp.percentile(D_draws, q=jnp.array([5.5, 94.5]), axis=0)
plt.plot(sim_dat["A"], jnp.mean(D_draws, 0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated A", ylabel="counterfactual D")
plt.fill_between(sim_dat["A"], D_band[0], D_band[1], color="k", alpha=0.2)
plt.title("Total counterfactual effect of A on D")
plt.show()
# Expected change in standardized D when median marriage age moves from 20 to
# 30 years; ages are standardized with mean 26.1 and stddev 1.24.
ages = jnp.array([20, 30])
sim2_dat = dict(A=(ages - 26.1) / 1.24)
s2 = Predictive(m5_3_A.model, post, return_sites=["M", "D"])(random.PRNGKey(2), **sim2_dat)
D_pair = s2["D"]
jnp.mean(D_pair[:, 1] - D_pair[:, 0])
DeviceArray(-4.6818223, dtype=float32)
# Counterfactual effect of manipulating M with A held at 0. Because M is passed
# in, the model's M site is observed at these values instead of drawn from A.
sim_dat = dict(M=jnp.linspace(-2, 2, num=30), A=0)
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)["D"]
plt.plot(sim_dat["M"], jnp.mean(s, 0))
# the x-axis is the manipulated M, not A
plt.gca().set(ylim=(-2, 2), xlabel="manipulated M", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["M"],
    *jnp.percentile(s, q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of M on D")
plt.show()
# Manual counterfactual simulation: first M | A, then D | A, M, for every
# posterior draw and every value of A_seq.
A_seq = jnp.linspace(-2, 2, num=30)
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
# add a trailing axis so each draw's parameters broadcast against the 30 A values
post = {k: v[..., None] for k, v in post.items()}
# separate keys: drawing both with one key would give M_sim and D_sim identical noise
key_M, key_D = random.split(random.PRNGKey(1))
# each simulation uses its own residual scale (sigma_M for M, sigma for D)
M_sim = dist.Normal(post["aM"] + post["bAM"] * A_seq, post["sigma_M"]).sample(key_M)
D_sim = dist.Normal(
    post["a"] + post["bA"] * A_seq + post["bM"] * M_sim, post["sigma"]
).sample(key_D)
# Primate milk composition data; `d` and `milk` name the same DataFrame.
d = milk = pd.read_csv("../data/milk.csv", sep=";")
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 29 entries, 0 to 28 Data columns (total 8 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 clade 29 non-null object 1 species 29 non-null object 2 kcal.per.g 29 non-null float64 3 perc.fat 29 non-null float64 4 perc.protein 29 non-null float64 5 perc.lactose 29 non-null float64 6 mass 29 non-null float64 7 neocortex.perc 17 non-null float64 dtypes: float64(6), object(2) memory usage: 1.9+ KB
clade | species | kcal.per.g | perc.fat | perc.protein | perc.lactose | mass | neocortex.perc | |
---|---|---|---|---|---|---|---|---|
0 | Strepsirrhine | Eulemur fulvus | 0.49 | 16.60 | 15.42 | 67.98 | 1.95 | 55.16 |
1 | Strepsirrhine | E macaco | 0.51 | 19.27 | 16.91 | 63.82 | 2.09 | NaN |
2 | Strepsirrhine | E mongoz | 0.46 | 14.11 | 16.85 | 69.04 | 2.51 | NaN |
3 | Strepsirrhine | E rubriventer | 0.48 | 14.91 | 13.18 | 71.91 | 1.62 | NaN |
4 | Strepsirrhine | Lemur catta | 0.60 | 27.28 | 19.50 | 53.22 | 2.19 | NaN |
# z-score energy density (K), neocortex percent (N) and log body mass (M).
# Missing neocortex values stay NaN in N.
kcal = d["kcal.per.g"]
d["K"] = (kcal - kcal.mean()) / kcal.std()
neocortex = d["neocortex.perc"]
d["N"] = (neocortex - neocortex.mean()) / neocortex.std()
log_mass = d.mass.map(math.log)
d["M"] = (log_mass - log_mass.mean()) / log_mass.std()
def model(N, K):
    """Draft of m5.5 with vague Normal(0, 1) priors: K regressed on N."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=1))
    slope = numpyro.sample("bN", dist.Normal(loc=0, scale=1))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    numpyro.sample("K", dist.Normal(intercept + slope * N, sd), obs=K)


# Fitting on every row fails: N contains NaN, so with validation enabled the
# Normal likelihood raises a ValueError, which is caught and printed.
with numpyro.validation_enabled():
    try:
        m5_5_draft = AutoLaplaceApproximation(model)
        svi = SVI(
            model,
            m5_5_draft,
            optim.Adam(step_size=1),
            Trace_ELBO(),
            N=d.N.values,
            K=d.K.values,
        )
        svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
        p5_5_draft = svi_result.params
    except ValueError as err:
        print(str(err))
Normal distribution got invalid loc parameter.
# Inspect the raw neocortex column; its NaN entries are what made N an invalid Normal loc.
d["neocortex.perc"]
0 55.16 1 NaN 2 NaN 3 NaN 4 NaN 5 64.54 6 64.54 7 67.64 8 NaN 9 68.85 10 58.85 11 61.69 12 60.32 13 NaN 14 NaN 15 69.97 16 NaN 17 70.41 18 NaN 19 73.40 20 NaN 21 67.53 22 NaN 23 71.26 24 72.60 25 NaN 26 70.24 27 76.30 28 75.49 Name: neocortex.perc, dtype: float64
# Complete cases: keep only rows where K, N and M are all present. dropna keeps
# the original row labels, so no positional re-indexing through .iloc is needed.
dcc = d.dropna(subset=["K", "N", "M"])
def model(N, K=None):
    """Draft m5.5 on complete cases, still with vague Normal(0, 1) priors."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=1))
    slope = numpyro.sample("bN", dist.Normal(loc=0, scale=1))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    mean = numpyro.deterministic("mu", intercept + slope * N)
    numpyro.sample("K", dist.Normal(mean, sd), obs=K)


m5_5_draft = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_5_draft,
    optim.Adam(step_size=0.1),
    Trace_ELBO(),
    N=dcc.N.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_5_draft = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1444.83it/s, init loss: 411.1621, avg. loss [951-1000]: 26.8758]
# Prior predictive lines of the vague-prior draft model over N in [-2, 2].
xseq = jnp.array([-2, 2])
prior_pred = Predictive(model, num_samples=1000)(random.PRNGKey(1), N=xseq)
mu = prior_pred["mu"]
plt.subplot(xlim=xseq, ylim=xseq)
for draw in mu[:50]:
    plt.plot(xseq, draw, "k", alpha=0.3)
def model(N, K=None):
    """m5.5: K on N with tighter priors a ~ Normal(0, 0.2), bN ~ Normal(0, 0.5)."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    slope = numpyro.sample("bN", dist.Normal(loc=0, scale=0.5))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    mean = numpyro.deterministic("mu", intercept + slope * N)
    numpyro.sample("K", dist.Normal(mean, sd), obs=K)


m5_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_5,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    N=dcc.N.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_5 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1399.46it/s, init loss: 414.1050, avg. loss [951-1000]: 24.6918]
# 1000 posterior draws of m5.5, summarized with 89% intervals.
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat a 0.05 0.16 0.05 -0.21 0.29 931.50 1.00 bN 0.13 0.23 0.13 -0.21 0.53 1111.88 1.00 mu[0] -0.22 0.51 -0.22 -0.99 0.60 911.65 1.00 mu[1] -0.02 0.20 -0.01 -0.36 0.27 906.66 1.00 mu[2] -0.02 0.20 -0.01 -0.36 0.27 906.66 1.00 mu[3] 0.05 0.16 0.05 -0.21 0.29 931.28 1.00 mu[4] 0.08 0.17 0.08 -0.18 0.36 940.50 1.00 mu[5] -0.14 0.38 -0.13 -0.74 0.45 889.84 1.00 mu[6] -0.08 0.28 -0.07 -0.52 0.36 874.50 1.00 mu[7] -0.11 0.33 -0.10 -0.65 0.37 884.63 1.00 mu[8] 0.10 0.18 0.10 -0.18 0.41 964.72 1.00 mu[9] 0.11 0.19 0.11 -0.17 0.43 975.68 1.00 mu[10] 0.17 0.28 0.17 -0.25 0.63 1037.05 1.00 mu[11] 0.05 0.16 0.05 -0.21 0.29 931.71 1.00 mu[12] 0.13 0.21 0.13 -0.18 0.49 996.35 1.00 mu[13] 0.16 0.25 0.15 -0.21 0.58 1024.00 1.00 mu[14] 0.11 0.19 0.10 -0.17 0.43 971.43 1.00 mu[15] 0.24 0.37 0.23 -0.39 0.80 1067.60 1.00 mu[16] 0.22 0.35 0.21 -0.37 0.74 1061.13 1.00 sigma 1.05 0.18 1.03 0.78 1.35 944.03 1.00
# Posterior mean and 89% interval of mu for m5.5 over a grid slightly wider
# than the observed N range, overlaid on the complete-case K-vs-N scatter.
lo, hi = dcc.N.min() - 0.15, dcc.N.max() + 0.15
xseq = jnp.linspace(start=lo, stop=hi, num=30)
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
post.pop("mu")  # let Predictive recompute mu on xseq
post_pred = Predictive(m5_5.model, posterior_samples=post)(random.PRNGKey(2), N=xseq)
mu = post_pred["mu"]
mu_mean = mu.mean(axis=0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
az.plot_pair(dcc[["N", "K"]].to_dict(orient="list"))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
def model(M, K=None):
    """m5.6: standardized milk energy K regressed on standardized log body mass M."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    slope = numpyro.sample("bM", dist.Normal(loc=0, scale=0.5))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    mean = numpyro.deterministic("mu", intercept + slope * M)
    numpyro.sample("K", dist.Normal(mean, sd), obs=K)


m5_6 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_6,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    M=dcc.M.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_6 = svi_result.params
post = m5_6.sample_posterior(random.PRNGKey(1), p5_6, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
100%|██████████| 1000/1000 [00:00<00:00, 1457.13it/s, init loss: 756.0300, avg. loss [951-1000]: 23.9327]
mean std median 5.5% 94.5% n_eff r_hat a 0.06 0.16 0.06 -0.20 0.29 931.50 1.00 bM -0.28 0.20 -0.28 -0.61 0.03 1088.44 1.00 mu[0] 0.18 0.18 0.19 -0.11 0.47 944.38 1.00 mu[1] 0.02 0.16 0.02 -0.24 0.26 932.90 1.00 mu[2] 0.02 0.16 0.02 -0.24 0.26 933.51 1.00 mu[3] 0.14 0.17 0.15 -0.11 0.42 958.01 1.00 mu[4] 0.36 0.27 0.36 -0.07 0.78 872.95 1.00 mu[5] 0.65 0.45 0.65 -0.03 1.39 898.89 1.00 mu[6] 0.42 0.31 0.43 -0.07 0.89 878.58 1.00 mu[7] 0.49 0.35 0.50 -0.09 1.00 885.51 1.00 mu[8] 0.22 0.20 0.23 -0.12 0.51 906.05 1.00 mu[9] 0.10 0.16 0.10 -0.17 0.33 940.69 1.00 mu[10] -0.12 0.20 -0.13 -0.40 0.21 989.01 1.00 mu[11] 0.02 0.16 0.02 -0.24 0.26 933.51 1.00 mu[12] -0.30 0.29 -0.30 -0.75 0.18 1050.56 1.00 mu[13] -0.43 0.38 -0.44 -1.07 0.12 1073.41 1.00 mu[14] -0.32 0.31 -0.33 -0.81 0.16 1055.58 1.00 mu[15] -0.29 0.28 -0.29 -0.74 0.16 1047.84 1.00 mu[16] -0.37 0.34 -0.38 -0.95 0.12 1064.72 1.00 sigma 0.99 0.17 0.98 0.72 1.26 957.10 1.00
def model(N, M, K=None):
    """m5.7: K regressed on both neocortex percent N and log body mass M."""
    intercept = numpyro.sample("a", dist.Normal(loc=0, scale=0.2))
    slope_n = numpyro.sample("bN", dist.Normal(loc=0, scale=0.5))
    slope_m = numpyro.sample("bM", dist.Normal(loc=0, scale=0.5))
    sd = numpyro.sample("sigma", dist.Exponential(rate=1))
    mean = numpyro.deterministic("mu", intercept + slope_n * N + slope_m * M)
    numpyro.sample("K", dist.Normal(mean, sd), obs=K)


m5_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_7,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    N=dcc.N.values,
    M=dcc.M.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p5_7 = svi_result.params
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
100%|██████████| 1000/1000 [00:00<00:00, 1324.26it/s, init loss: 136.3944, avg. loss [951-1000]: 21.6292]
mean std median 5.5% 94.5% n_eff r_hat a 0.06 0.13 0.06 -0.15 0.26 1049.96 1.00 bM -0.68 0.23 -0.68 -1.06 -0.32 837.54 1.00 bN 0.65 0.25 0.66 0.25 1.06 885.39 1.00 mu[0] -0.99 0.47 -1.00 -1.72 -0.22 953.05 1.00 mu[1] -0.36 0.20 -0.35 -0.64 -0.03 1076.34 1.00 mu[2] -0.37 0.20 -0.36 -0.65 -0.03 1068.21 1.00 mu[3] 0.28 0.16 0.28 0.03 0.52 967.73 1.00 mu[4] 0.94 0.33 0.94 0.43 1.44 876.35 1.00 mu[5] 0.53 0.39 0.53 -0.10 1.11 1015.67 1.00 mu[6] 0.30 0.26 0.30 -0.10 0.73 1041.23 1.00 mu[7] 0.30 0.30 0.30 -0.16 0.76 1039.12 1.00 mu[8] 0.73 0.27 0.73 0.35 1.17 867.10 1.00 mu[9] 0.48 0.21 0.49 0.16 0.80 872.50 1.00 mu[10] 0.27 0.23 0.28 -0.06 0.66 895.61 1.00 mu[11] -0.04 0.14 -0.04 -0.25 0.17 992.18 1.00 mu[12] -0.39 0.25 -0.39 -0.81 -0.00 864.74 1.00 mu[13] -0.56 0.32 -0.57 -1.03 0.01 871.33 1.00 mu[14] -0.55 0.27 -0.55 -0.98 -0.10 858.97 1.00 mu[15] 0.20 0.30 0.20 -0.29 0.66 916.65 1.00 mu[16] -0.10 0.30 -0.10 -0.59 0.37 895.92 1.00 sigma 0.77 0.14 0.77 0.55 0.97 1029.58 1.00
# Slope posteriors of m5.5, m5.6 and m5.7 side by side; sample_shape=(1, 1000)
# gives each model one "chain" of 1000 draws for ArviZ.
fits = [
    ("m5.5", m5_5, p5_5, 1),
    ("m5.6", m5_6, p5_6, 2),
    ("m5.7", m5_7, p5_7, 3),
]
coeftab = {
    label: guide.sample_posterior(random.PRNGKey(seed), params, sample_shape=(1, 1000))
    for label, guide, params, seed in fits
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bM", "bN"],
    hdi_prob=0.89,
)
plt.show()
# Counterfactual for m5.7: vary N across its observed range with M held at 0.
xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
post.pop("mu")  # let Predictive recompute mu on xseq
post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# the x-axis shows N, so its limits come from N (not M)
plt.subplot(
    xlim=(dcc.N.min(), dcc.N.max()),
    ylim=(dcc.K.min(), dcc.K.max()),
    xlabel="neocortex percent (std)",
    ylabel="kilocal per g (std)",
)
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()
# Simulation 1 -- DAG: M -> K <- N, plus M -> N
n = 100
M = dist.Normal().sample(random.PRNGKey(0), sample_shape=(n,))
N = dist.Normal(loc=M).sample(random.PRNGKey(1))
K = dist.Normal(loc=N - M).sample(random.PRNGKey(2))
d_sim = pd.DataFrame(dict(K=K, N=N, M=M))
# Simulation 2 -- DAG: M -> K <- N, plus N -> M
n = 100
N = dist.Normal().sample(random.PRNGKey(0), sample_shape=(n,))
M = dist.Normal(loc=N).sample(random.PRNGKey(1))
K = dist.Normal(loc=N - M).sample(random.PRNGKey(2))
d_sim2 = pd.DataFrame(dict(K=K, N=N, M=M))
# Simulation 3 -- DAG: M -> K <- N, with an unobserved common cause M <- U -> N.
# U must be simulated explicitly and used as the mean of both N and M;
# otherwise N and M are unrelated (or M depends on stale values from an earlier cell).
n = 100
U = dist.Normal().sample(random.PRNGKey(3), (n,))
N = dist.Normal(U).sample(random.PRNGKey(4))
M = dist.Normal(U).sample(random.PRNGKey(5))
K = dist.Normal(N - M).sample(random.PRNGKey(6))
d_sim3 = pd.DataFrame({"K": K, "N": N, "M": M})
# Enumerate every orientation of the skeleton of M -> K <- N, M -> N by
# flipping each subset of its three edges, keeping only the acyclic results.
dag5_7 = nx.DiGraph()
dag5_7.add_edges_from([("M", "K"), ("N", "K"), ("M", "N")])
coordinates = {"M": (0, 0.5), "K": (1, 1), "N": (2, 0.5)}
MElist = []
# product(range(2), repeat=3) walks the flip patterns in the same order as three nested loops
for flips in itertools.product(range(2), repeat=3):
    new_dag = nx.DiGraph()
    new_dag.add_edges_from(
        [(head, tail) if flip else (tail, head) for (tail, head), flip in zip(dag5_7.edges, flips)]
    )
    if not list(nx.simple_cycles(new_dag)):
        MElist.append(new_dag)
# Howell1 data (height, weight, age, male indicator); `d` names the same DataFrame.
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 544 entries, 0 to 543 Data columns (total 4 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 height 544 non-null float64 1 weight 544 non-null float64 2 age 544 non-null float64 3 male 544 non-null int64 dtypes: float64(3), int64(1) memory usage: 17.1 KB
height | weight | age | male | |
---|---|---|---|---|
0 | 151.765 | 47.825606 | 63.0 | 1 |
1 | 139.700 | 36.485807 | 63.0 | 0 |
2 | 136.525 | 31.864838 | 65.0 | 0 |
3 | 156.845 | 53.041915 | 41.0 | 1 |
4 | 145.415 | 41.276872 | 51.0 | 0 |
# Prior simulation for the indicator-variable approach: the male mean is the
# female-style prior plus an extra offset, so its prior is wider.
n_draws = int(1e4)
mu_female = dist.Normal(178, 20).sample(random.PRNGKey(0), (n_draws,))
diff = dist.Normal(0, 10).sample(random.PRNGKey(1), (n_draws,))
mu_male = dist.Normal(178, 20).sample(random.PRNGKey(2), (n_draws,)) + diff
print_summary({"mu_female": mu_female, "mu_male": mu_male}, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat mu_female 178.21 20.22 178.24 147.19 211.84 9943.61 1.00 mu_male 178.10 22.36 178.51 142.35 213.41 10190.57 1.00
# Index variable: 1 for rows with male == 1, otherwise 0.
is_male = d.male.values == 1
d["sex"] = jnp.where(is_male, 1, 0)
d.sex
0 1 1 0 2 0 3 1 4 0 .. 539 1 540 1 541 0 542 1 543 1 Name: sex, Length: 544, dtype: int32
def model(sex, height):
    """m5.8: one mean height per sex index, with a shared sigma."""
    n_groups = len(set(sex))
    group_means = numpyro.sample("a", dist.Normal(loc=178, scale=20).expand([n_groups]))
    sd = numpyro.sample("sigma", dist.Uniform(low=0, high=50))
    numpyro.sample("height", dist.Normal(group_means[sex], sd), obs=height)


m5_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_8,
    optim.Adam(step_size=1),
    Trace_ELBO(),
    sex=d.sex.values,
    height=d.height.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=2000)
p5_8 = svi_result.params
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
print_summary(post, prob=0.89, group_by_chain=False)
100%|██████████| 2000/2000 [00:01<00:00, 1815.39it/s, init loss: 5607.9023, avg. loss [1901-2000]: 2558.3149]
mean std median 5.5% 94.5% n_eff r_hat a[0] 135.02 1.63 135.07 132.32 137.46 931.50 1.00 a[1] 142.56 1.73 142.54 140.02 145.51 1111.51 1.00 sigma 27.32 0.84 27.32 26.03 28.71 951.62 1.00
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
# contrast between the sex-0 and sex-1 means (female minus male)
a_draws = post["a"]
post["diff_fm"] = a_draws[:, 0] - a_draws[:, 1]
print_summary(post, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat a[0] 135.02 1.63 135.07 132.32 137.46 931.50 1.00 a[1] 142.56 1.73 142.54 140.02 145.51 1111.51 1.00 diff_fm -7.54 2.38 -7.47 -11.77 -4.32 876.56 1.00 sigma 27.32 0.84 27.32 26.03 28.71 951.62 1.00
# Reload the primate milk data and list its clades.
d = milk = pd.read_csv("../data/milk.csv", sep=";")
d.clade.unique()
array(['Strepsirrhine', 'New World Monkey', 'Old World Monkey', 'Ape'], dtype=object)
# Integer clade codes (pandas category codes) and standardized energy K.
d["clade_id"] = d.clade.astype("category").cat.codes
kcal = d["kcal.per.g"]
d["K"] = (kcal - kcal.mean()) / kcal.std()
def model(clade_id, K):
    """m5.9: one expected standardized milk energy K per clade (index variable).

    clade_id : integer clade code per species.
    K : observed standardized kcal.per.g.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id]
    # the likelihood site is named after the data it observes (K), not "height"
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_9, optim.Adam(1), Trace_ELBO(), clade_id=d.clade_id.values, K=d.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_9 = svi_result.params
post = m5_9.sample_posterior(random.PRNGKey(1), p5_9, sample_shape=(1000,))
# category codes follow the sorted clade names, so label them in that order
labels = [f"a[{i}]:{s}" for i, s in enumerate(sorted(d.clade.unique()))]
az.plot_forest({"a": post["a"][None, ...]}, hdi_prob=0.89)
# tick labels are assigned bottom-to-top, hence the reversal
# (assumes ArviZ draws a[0] at the top -- confirm against the rendered plot)
plt.gca().set(yticklabels=labels[::-1], xlabel="expected kcal (std)")
plt.show()
100%|██████████| 1000/1000 [00:00<00:00, 1407.72it/s, init loss: 94.6847, avg. loss [951-1000]: 35.5646]
# Randomly assign each species to one of four "houses": a pool of 8 slots per
# house, drawn without replacement, one per row of d.
key = random.PRNGKey(63)
house_pool = jnp.repeat(jnp.arange(4), 8)
d["house"] = random.choice(key, house_pool, shape=d.shape[:1], replace=False)
def model(clade_id, house, K):
    """m5.10: standardized milk energy K with one intercept per clade plus one per house.

    clade_id : integer clade code per species.
    house : integer house index per species.
    K : observed standardized kcal.per.g.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    h = numpyro.sample("h", dist.Normal(0, 0.5).expand([len(set(house))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id] + h[house]
    # the likelihood site is named after the data it observes (K), not "height"
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_10,
    optim.Adam(1),
    Trace_ELBO(),
    clade_id=d.clade_id.values,
    house=d.house.values,
    K=d.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_10 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1231.03it/s, init loss: 491.4240, avg. loss [951-1000]: 35.8243]