#!/usr/bin/env python
# coding: utf-8

# # Chapter 5. The Many Variables & The Spurious Waffles
#
# Notebook-style script (exported from Jupyter): code runs top-to-bottom as
# sequential "cells", each tagged with the book's code-snippet number.
# NOTE(review): `get_ipython()` only exists when running under IPython/Jupyter.

# In[ ]:

get_ipython().system('pip install -q numpyro arviz daft networkx')

# In[1]:

import collections
import itertools
import math
import os

import arviz as az
import daft
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    get_ipython().run_line_magic('config', 'InlineBackend.figure_formats = ["svg"]')
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")


# ### Code 5.1

# In[2]:

# load data and copy
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce

# standardize variables (z-scores)
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())


# ### Code 5.2

# In[3]:

d.MedianAgeMarriage.std()


# ### Code 5.3

# In[4]:


def model(A, D=None):
    """m5.1: regress divorce rate D on median age at marriage A (both std.)."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


# Laplace approximation of the posterior (quap analogue); Adam(1) is the
# large step size used throughout this chapter for the quadratic fit.
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_1, optim.Adam(1), Trace_ELBO(), A=d.A.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_1 = svi_result.params


# ### Code 5.4

# In[5]:

# prior predictive: sample regression lines implied by the priors alone
predictive = Predictive(m5_1.model, num_samples=1000, return_sites=["mu"])
prior_pred = predictive(random.PRNGKey(10), A=jnp.array([-2, 2]))
mu = prior_pred["mu"]
plt.subplot(xlim=(-2, 2), ylim=(-2, 2))
for i in range(20):
    plt.plot([-2, 2], mu[i], "k", alpha=0.4)


# ### Code 5.5

# In[6]:

# compute percentile interval of mean
A_seq = jnp.linspace(start=-3, stop=3.2, num=30)
post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))
# drop the deterministic site so Predictive recomputes mu for the new A grid
post.pop("mu")
post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)

# plot it all
az.plot_pair(d[["D", "A"]].to_dict(orient="list"))
plt.plot(A_seq, mu_mean, "k")
plt.fill_between(A_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()


# ### Code 5.6

# In[7]:

d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(M, D=None):
    """m5.2: regress divorce rate D on marriage rate M (both standardized)."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_2, optim.Adam(1), Trace_ELBO(), M=d.M.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_2 = svi_result.params


# ### Code 5.7

# In[8]:

# draw the DAG A -> D, A -> M, M -> D with daft
dag5_1 = nx.DiGraph()
dag5_1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
pgm = daft.PGM()
coordinates = {"A": (0, 0), "D": (1, 1), "M": (2, 0)}
for node in dag5_1.nodes:
    pgm.add_node(node, node, *coordinates[node])
for edge in dag5_1.edges:
    pgm.add_edge(*edge)
with plt.rc_context({"figure.constrained_layout.use": False}):
    pgm.render()
plt.gca().invert_yaxis()
plt.show()


# ### Code 5.8

# In[9]:

# enumerate (minimal) conditional independencies implied by DAG A -> D, A -> M
# NOTE(review): nx.d_separated was renamed nx.is_d_separator in NetworkX 3.2;
# this name still works on the versions installed above — confirm if upgrading.
DMA_dag2 = nx.DiGraph()
DMA_dag2.add_edges_from([("A", "D"), ("A", "M")])
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag2.nodes), 2):
    remaining = sorted(set(DMA_dag2.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            # skip supersets of an already-found separating set
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if nx.d_separated(DMA_dag2, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))


# ### Code 5.9

# In[10]:

# same enumeration for the full DAG A -> D, A -> M, M -> D (no independencies)
DMA_dag1 = nx.DiGraph()
DMA_dag1.add_edges_from([("A", "D"), ("A", "M"), ("M", "D")])
conditional_independencies = collections.defaultdict(list)
for edge in itertools.combinations(sorted(DMA_dag1.nodes), 2):
    remaining = sorted(set(DMA_dag1.nodes) - set(edge))
    for size in range(len(remaining) + 1):
        for subset in itertools.combinations(remaining, size):
            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):
                continue
            if nx.d_separated(DMA_dag1, {edge[0]}, {edge[1]}, set(subset)):
                conditional_independencies[edge].append(set(subset))
                print(f"{edge[0]} _||_ {edge[1]}" + (f" | {' '.join(subset)}" if subset else ""))


# ### Code 5.10

# In[11]:


def model(M, A, D=None):
    """m5.3: multiple regression of D on both M and A."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_3, optim.Adam(1), Trace_ELBO(), M=d.M.values, A=d.A.values, D=d.D.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3 = svi_result.params
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(1000,))
print_summary(post, 0.89, False)


# ### Code 5.11

# In[12]:

# compare bA/bM across the three models (extra leading dim = one "chain")
coeftab = {
    "m5.1": m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1, 1000)),
    "m5.2": m5_2.sample_posterior(random.PRNGKey(2), p5_2, sample_shape=(1, 1000)),
    "m5.3": m5_3.sample_posterior(random.PRNGKey(3), p5_3, sample_shape=(1, 1000)),
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bA", "bM"],
    hdi_prob=0.89,
)
plt.show()


# ### Code 5.12

# In[13]:

N = 50  # number of simulated States
age = dist.Normal().sample(random.PRNGKey(0), sample_shape=(N,))  # sim A
mar = dist.Normal(age).sample(random.PRNGKey(1))  # sim A -> M
div = dist.Normal(age).sample(random.PRNGKey(2))  # sim A -> D


# ### Code 5.13

# In[14]:


def model(A, M=None):
    """m5.4: regress marriage rate M on age at marriage A (for residuals)."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    # fix: site was named "bA", inconsistent with the variable and the book
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bAM * A)
    numpyro.sample("M", dist.Normal(mu, sigma), obs=M)


m5_4 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_4, optim.Adam(0.1), Trace_ELBO(), A=d.A.values, M=d.M.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_4 = svi_result.params


# ### Code 5.14

# In[15]:

post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))
post.pop("mu")
post_pred = Predictive(m5_4.model, post)(random.PRNGKey(2), A=d.A.values)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_resid = d.M.values - mu_mean  # marriage rate residual after removing A


# ### Code 5.15

# In[16]:

# call predictive without specifying new data
# so it uses original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, sample_shape=(int(1e4),))
post.pop("mu")
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu = post_pred["mu"]

# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)

# simulate observations
# again no new data, so uses original data
D_sim = post_pred["D"]
D_PI = jnp.percentile(D_sim, q=jnp.array([5.5, 94.5]), axis=0)


# ### Code 5.16

# In[17]:

# posterior predictive check: observed vs. predicted divorce with 89% PIs
ax = plt.subplot(
    ylim=(float(mu_PI.min()), float(mu_PI.max())),
    xlabel="Observed divorce",
    ylabel="Predicted divorce",
)
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")
fig = plt.gcf()


# ### Code 5.17

# In[18]:

# label the States that the model predicts worst
for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(
            d.Loc[i], (d.D[i], mu_mean[i]), xytext=(-25, -5), textcoords="offset pixels"
        )
fig


# ### Code 5.18

# In[19]:

# simulate a spurious association: x_real drives both x_spur and y
N = 100  # number of cases
# x_real as Gaussian with mean 0 and stddev 1
x_real = dist.Normal().sample(random.PRNGKey(0), (N,))
# x_spur as Gaussian with mean=x_real
x_spur = dist.Normal(x_real).sample(random.PRNGKey(1))
# y as Gaussian with mean=x_real
y = dist.Normal(x_real).sample(random.PRNGKey(2))
# bind all together in data frame
d = pd.DataFrame({"y": y, "x_real": x_real, "x_spur": x_spur})


# ### Code 5.19

# In[20]:

# full generative model of the divorce data (A -> M and A -> D <- M),
# so we can simulate counterfactual interventions
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())


def model(A, M=None, D=None):
    """m5.3_A: joint model of M and D given A, for counterfactual simulation."""
    # A -> M
    aM = numpyro.sample("aM", dist.Normal(0, 0.2))
    bAM = numpyro.sample("bAM", dist.Normal(0, 0.5))
    sigma_M = numpyro.sample("sigma_M", dist.Exponential(1))
    mu_M = aM + bAM * A
    M = numpyro.sample("M", dist.Normal(mu_M, sigma_M), obs=M)
    # A -> D <- M
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bM * M + bA * A
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3_A = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_3_A,
    optim.Adam(0.1),
    Trace_ELBO(),
    A=d.A.values,
    M=d.M.values,
    D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3_A = svi_result.params


# ### Code 5.20

# In[21]:

A_seq = jnp.linspace(-2, 2, num=30)


# ### Code 5.21

# In[22]:

# prep data
sim_dat = dict(A=A_seq)

# simulate M and then D, using A_seq
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)


# ### Code 5.22

# In[23]:

plt.plot(sim_dat["A"], jnp.mean(s["D"], 0))
plt.gca().set(ylim=(-2, 2), xlabel="manipulated A", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["A"],
    *jnp.percentile(s["D"], q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of A on D")
plt.show()


# ### Code 5.23

# In[24]:

# new data frame, standardized to mean 26.1 and stddev 1.24
sim2_dat = dict(A=(jnp.array([20, 30]) - 26.1) / 1.24)
s2 = Predictive(m5_3_A.model, post, return_sites=["M", "D"])(
    random.PRNGKey(2), **sim2_dat
)
jnp.mean(s2["D"][:, 1] - s2["D"][:, 0])


# ### Code 5.24

# In[25]:

# manipulating M breaks the A -> M arrow, so A is held at its mean (0)
sim_dat = dict(M=jnp.linspace(-2, 2, num=30), A=0)
s = Predictive(m5_3_A.model, post)(random.PRNGKey(2), **sim_dat)["D"]

plt.plot(sim_dat["M"], jnp.mean(s, 0))
# fix: x-axis shows the manipulated M, not A
plt.gca().set(ylim=(-2, 2), xlabel="manipulated M", ylabel="counterfactual D")
plt.fill_between(
    sim_dat["M"],
    *jnp.percentile(s, q=jnp.array([5.5, 94.5]), axis=0),
    color="k",
    alpha=0.2,
)
plt.title("Total counterfactual effect of M on D")
plt.show()


# ### Code 5.25

# In[26]:

A_seq = jnp.linspace(-2, 2, num=30)


# ### Code 5.26

# In[27]:

# counterfactual simulation "by hand": add a trailing axis so each posterior
# draw broadcasts against the whole A_seq grid
post = m5_3_A.sample_posterior(random.PRNGKey(1), p5_3_A, sample_shape=(1000,))
post = {k: v[..., None] for k, v in post.items()}
M_sim = dist.Normal(post["aM"] + post["bAM"] * A_seq).sample(random.PRNGKey(1))


# ### Code 5.27

# In[28]:

D_sim = dist.Normal(post["a"] + post["bA"] * A_seq + post["bM"] * M_sim).sample(
    random.PRNGKey(1)
)


# ### Code 5.28

# In[29]:

milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.info()
d.head()


# ### Code 5.29

# In[30]:

# standardize; body mass is standardized on the log scale
d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())
d["N"] = d["neocortex.perc"].pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.mass.map(math.log).pipe(lambda x: (x - x.mean()) / x.std())


# ### Code 5.30

# In[31]:


def model(N, K):
    """m5.5_draft: K on N with vague priors; fails on the NaNs in N."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bN * N
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


# validation_enabled makes the NaN observations raise instead of silently
# corrupting the fit; the error message is printed on purpose
with numpyro.validation_enabled():
    try:
        m5_5_draft = AutoLaplaceApproximation(model)
        svi = SVI(
            model, m5_5_draft, optim.Adam(1), Trace_ELBO(), N=d.N.values, K=d.K.values
        )
        svi_result = svi.run(random.PRNGKey(0), 1000)
        p5_5_draft = svi_result.params
    except ValueError as e:
        print(str(e))


# ### Code 5.31

# In[32]:

d["neocortex.perc"]


# ### Code 5.32

# In[33]:

# complete-case subset: rows where K, N and M are all present
dcc = d.iloc[d[["K", "N", "M"]].dropna(how="any", axis=0).index]


# ### Code 5.33

# In[34]:


def model(N, K=None):
    """m5.5_draft refit on complete cases, still with the vague priors."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    bN = numpyro.sample("bN", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5_draft = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_5_draft, optim.Adam(0.1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5_draft = svi_result.params


# ### Code 5.34

# In[35]:

# prior predictive lines for the vague priors (absurdly wide)
xseq = jnp.array([-2, 2])
prior_pred = Predictive(model, num_samples=1000)(random.PRNGKey(1), N=xseq)
mu = prior_pred["mu"]
plt.subplot(xlim=xseq, ylim=xseq)
for i in range(50):
    plt.plot(xseq, mu[i], "k", alpha=0.3)


# ### Code 5.35

# In[36]:


def model(N, K=None):
    """m5.5: K on N with the tightened priors."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_5 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_5, optim.Adam(1), Trace_ELBO(), N=dcc.N.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_5 = svi_result.params


# ### Code 5.36

# In[37]:

post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
print_summary(post, 0.89, False)


# ### Code 5.37

# In[38]:

xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))
post.pop("mu")
post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
az.plot_pair(dcc[["N", "K"]].to_dict(orient="list"))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()


# ### Code 5.38

# In[39]:


def model(M, K=None):
    """m5.6: K on log body mass M."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_6, optim.Adam(1), Trace_ELBO(), M=dcc.M.values, K=dcc.K.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_6 = svi_result.params
post = m5_6.sample_posterior(random.PRNGKey(1), p5_6, sample_shape=(1000,))
print_summary(post, 0.89, False)


# ### Code 5.39

# In[40]:


def model(N, M, K=None):
    """m5.7: K on both N and M (masked-relationship model)."""
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bN = numpyro.sample("bN", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bN * N + bM * M)
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m5_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_7,
    optim.Adam(1),
    Trace_ELBO(),
    N=dcc.N.values,
    M=dcc.M.values,
    K=dcc.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_7 = svi_result.params
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
print_summary(post, 0.89, False)


# ### Code 5.40

# In[41]:

coeftab = {
    "m5.5": m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1, 1000)),
    "m5.6": m5_6.sample_posterior(random.PRNGKey(2), p5_6, sample_shape=(1, 1000)),
    "m5.7": m5_7.sample_posterior(random.PRNGKey(3), p5_7, sample_shape=(1, 1000)),
}
az.plot_forest(
    list(coeftab.values()),
    model_names=list(coeftab.keys()),
    var_names=["bM", "bN"],
    hdi_prob=0.89,
)
plt.show()


# ### Code 5.41

# In[42]:

# counterfactual: vary N over its observed range while holding M at 0
xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)
post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))
post.pop("mu")
post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)
mu = post_pred["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)
# fix: x-limits follow the plotted variable N (original took them from dcc.M)
plt.subplot(xlim=(dcc.N.min(), dcc.N.max()), ylim=(dcc.K.min(), dcc.K.max()))
plt.plot(xseq, mu_mean, "k")
plt.fill_between(xseq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.show()


# ### Code 5.42

# In[43]:

# M -> K <- N
# M -> N
n = 100
M = dist.Normal().sample(random.PRNGKey(0), (n,))
N = dist.Normal(M).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))
d_sim = pd.DataFrame({"K": K, "N": N, "M": M})


# ### Code 5.43

# In[44]:

# M -> K <- N
# N -> M
n = 100
N = dist.Normal().sample(random.PRNGKey(0), (n,))
M = dist.Normal(N).sample(random.PRNGKey(1))
K = dist.Normal(N - M).sample(random.PRNGKey(2))
d_sim2 = pd.DataFrame({"K": K, "N": N, "M": M})

# M -> K <- N
# M <- U -> N
# fix: the original drew M from the *stale* M of the previous simulation
# (dist.Normal(M)); the DAG requires both M and N to share an unobserved
# confound U, as in the book's R code 5.43.
n = 100
U = dist.Normal().sample(random.PRNGKey(3), (n,))
N = dist.Normal(U).sample(random.PRNGKey(4))
M = dist.Normal(U).sample(random.PRNGKey(5))
K = dist.Normal(N - M).sample(random.PRNGKey(6))
d_sim3 = pd.DataFrame({"K": K, "N": N, "M": M})


# ### Code 5.44

# In[45]:

# enumerate the Markov equivalence set: flip each of the three edges in turn
# and keep every acyclic orientation
dag5_7 = nx.DiGraph()
dag5_7.add_edges_from([("M", "K"), ("N", "K"), ("M", "N")])
coordinates = {"M": (0, 0.5), "K": (1, 1), "N": (2, 0.5)}
MElist = []
for i in range(2):
    for j in range(2):
        for k in range(2):
            new_dag = nx.DiGraph()
            new_dag.add_edges_from(
                [edge[::-1] if flip else edge for edge, flip in zip(dag5_7.edges, (i, j, k))]
            )
            if not list(nx.simple_cycles(new_dag)):
                MElist.append(new_dag)


# ### Code 5.45

# In[46]:

Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d = Howell1
d.info()
d.head()


# ### Code 5.46

# In[47]:

# indicator-variable parameterization: males get extra prior uncertainty
mu_female = dist.Normal(178, 20).sample(random.PRNGKey(0), (int(1e4),))
diff = dist.Normal(0, 10).sample(random.PRNGKey(1), (int(1e4),))
mu_male = dist.Normal(178, 20).sample(random.PRNGKey(2), (int(1e4),)) + diff
print_summary({"mu_female": mu_female, "mu_male": mu_male}, 0.89, False)


# ### Code 5.47

# In[48]:

# index variable: 0 = female, 1 = male
d["sex"] = jnp.where(d.male.values == 1, 1, 0)
d.sex


# ### Code 5.48

# In[49]:


def model(sex, height):
    """m5.8: one intercept per sex category (index-variable coding)."""
    a = numpyro.sample("a", dist.Normal(178, 20).expand([len(set(sex))]))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = a[sex]
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)


m5_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_8, optim.Adam(1), Trace_ELBO(), sex=d.sex.values, height=d.height.values
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p5_8 = svi_result.params
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
print_summary(post, 0.89, False)


# ### Code 5.49

# In[50]:

# contrast: expected female minus male height, computed per posterior draw
post = m5_8.sample_posterior(random.PRNGKey(1), p5_8, sample_shape=(1000,))
post["diff_fm"] = post["a"][:, 0] - post["a"][:, 1]
print_summary(post, 0.89, False)


# ### Code 5.50

# In[51]:

milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d.clade.unique()


# ### Code 5.51

# In[52]:

d["clade_id"] = d.clade.astype("category").cat.codes


# ### Code 5.52

# In[53]:

d["K"] = d["kcal.per.g"].pipe(lambda x: (x - x.mean()) / x.std())


def model(clade_id, K):
    """m5.9: one intercept per primate clade."""
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id]
    # NOTE(review): the observed site is named "height" although it observes K;
    # kept as-is since nothing reads the site by name — confirm before renaming
    numpyro.sample("height", dist.Normal(mu, sigma), obs=K)


m5_9 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_9, optim.Adam(1), Trace_ELBO(), clade_id=d.clade_id.values, K=d.K.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_9 = svi_result.params
post = m5_9.sample_posterior(random.PRNGKey(1), p5_9, sample_shape=(1000,))
labels = ["a[" + str(i) + "]:" + s for i, s in enumerate(sorted(d.clade.unique()))]
az.plot_forest({"a": post["a"][None, ...]}, hdi_prob=0.89)
plt.gca().set(yticklabels=labels[::-1], xlabel="expected kcal (std)")
plt.show()


# ### Code 5.53

# In[54]:

# assign each species to one of four made-up "houses" at random
key = random.PRNGKey(63)
d["house"] = random.choice(key, jnp.repeat(jnp.arange(4), 8), d.shape[:1], False)


# ### Code 5.54

# In[55]:


def model(clade_id, house, K):
    """m5.10: intercepts for both clade and (random, meaningless) house."""
    a = numpyro.sample("a", dist.Normal(0, 0.5).expand([len(set(clade_id))]))
    h = numpyro.sample("h", dist.Normal(0, 0.5).expand([len(set(house))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a[clade_id] + h[house]
    numpyro.sample("height", dist.Normal(mu, sigma), obs=K)


m5_10 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m5_10,
    optim.Adam(1),
    Trace_ELBO(),
    clade_id=d.clade_id.values,
    house=d.house.values,
    K=d.K.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_10 = svi_result.params