Purpose

Figure 2.29: (left) Overall negative log probability for the original model and the model with learned guess probabilities. The lower red bar indicates that learning the guess probabilities gives a substantially better model, according to this metric. (right) Negative log probability for each skill, showing that the improvement varies from skill to skill.

• The figure does not show the `log_density` of the model; it shows the negative log probability of the ground truth. For a given participant and $skill_i$, the negative log probability is
$$-\log(p(skill_i = truth_i))$$

where $truth_i$ is an indicator variable for having $skill_i$, and each skill is distributed as $skill_i \sim Bernoulli(\theta_i)$.
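
Because $truth_i$ takes only the values 0 and 1, the expression above can be expanded with the Bernoulli probability mass function. This is a standard identity, written with the same symbols as above:

$$-\log p(skill_i = truth_i) = -\left[\, truth_i \log \theta_i + (1 - truth_i) \log(1 - \theta_i) \,\right]$$

If the participant has the skill ($truth_i = 1$) the score is $-\log \theta_i$; otherwise it is $-\log(1 - \theta_i)$.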

Further details from the text:

A common metric to use is the probability of the ground truth values under the inferred distributions. Sometimes it is convenient to take the logarithm of the probability, since this gives a more manageable number when the probability is very small. When we use the logarithm of the probability, the metric is referred to as the log probability. So, if the inferred probability of a person having a particular skill is $p$, then the log probability is $log(p)$ if the person has the skill and $log(1−p)$ if they don’t. If the person does have the skill then the best possible prediction is $p=1.0$, which gives log probability of $log(1.0)=0$ (the logarithm of one is zero). A less confident prediction, such as $p=0.8$ will give a log probability with a negative value, in this case $log(0.8)=−0.097$. The worst possible prediction of $p=0.0$ gives a log probability of negative infinity. ...
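
The arithmetic in the quoted passage can be checked with a small sketch. The helper `log_prob` and its `base` argument are hypothetical names introduced here for illustration only. Note that the quoted value of $−0.097$ corresponds to a base-10 logarithm, whereas `neg_log_proba_score` in this notebook uses `np.log` (the natural logarithm), which gives about $−0.223$ for the same $p$.

```python
import math


def log_prob(p: float, has_skill: bool, base: float = math.e) -> float:
    """Log probability of the ground truth under an inferred probability p.

    Hypothetical helper for illustration; not part of the notebook.
    """
    q = p if has_skill else 1.0 - p
    if q == 0.0:
        # The worst possible prediction: log(0) is negative infinity.
        return float("-inf")
    return math.log(q, base)


print(round(log_prob(0.8, True, base=10), 3))  # -0.097 (base 10, as quoted)
print(round(log_prob(0.8, True), 3))           # -0.223 (natural log)
print(log_prob(1.0, True))                     # 0.0
print(log_prob(0.0, True))                     # -inf
```

The choice of base only rescales the metric by a constant factor, so comparisons between models are unaffected as long as one base is used throughout.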

In [1]:
import operator
from functools import reduce
from typing import Callable, Dict, List

import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from jax.scipy.special import logsumexp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, log_likelihood
from numpyro.infer.util import log_density, potential_energy

In [2]:
%matplotlib inline

In [3]:
%watermark -v -m -p arviz,jax,matplotlib,numpy,pandas,scipy,numpyro

Python implementation: CPython
Python version       : 3.8.11
IPython version      : 7.18.1

arviz     : 0.11.2
jax       : 0.2.19
matplotlib: 3.4.3
numpy     : 1.20.3
pandas    : 1.3.2
scipy     : 1.6.2
numpyro   : 0.7.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 4.19.193-1-MANJARO
Machine     : x86_64
Processor   :
CPU cores   : 4
Architecture: 64bit


In [4]:
%watermark -gb

Git hash: 307321cc497d1542d2908d60950823b102b16219

Git branch: master


In [5]:
def neg_log_proba_score(theta: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Calculates the negative log probability of the ground truth, the self-assessed skills.
    :param theta np.ndarray: array of Bernoulli probabilities
    :param y_true np.ndarray, dtype == int: array of indicator variables for skill of participants
    """
    assert theta.shape == y_true.shape
    assert np.issubdtype(y_true.dtype, np.integer)
    score = scipy.stats.bernoulli(theta).pmf(y_true)
    # Replace zero probabilities with machine epsilon so the log stays finite.
    score[score == 0.0] = np.finfo(float).eps
    return -np.log(score)

def plot_bars(
    data: np.array,
    columns: List[str],
    index: List[str],
    ax=None,
    tick_step=0.5,
    **kwargs,
):
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    pd.DataFrame(data, columns=columns, index=index).plot(
        kind="bar", color=["b", "r"], ax=ax, zorder=3, **kwargs
    )
    ax.grid(zorder=0, axis="y")
    ax.yaxis.set_ticks(np.arange(0, data.max(), tick_step));
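
`neg_log_proba_score` replaces zero probabilities with machine epsilon before taking the log, which caps the penalty for a confident but wrong prediction instead of returning infinity. The following is a minimal scalar sketch of that behaviour, using only the standard library; `neg_log_proba` is a hypothetical stand-in for illustration, not the notebook's vectorised function.

```python
import math
import sys


def neg_log_proba(theta: float, truth: int) -> float:
    """Scalar sketch of the clipped negative log probability (hypothetical helper)."""
    prob = theta if truth == 1 else 1.0 - theta
    if prob == 0.0:
        # sys.float_info.epsilon is the same value as np.finfo(float).eps (2**-52).
        prob = sys.float_info.epsilon
    return -math.log(prob)


# A posterior mean of exactly 1.0 for a skill the participant does not report:
capped = neg_log_proba(1.0, 0)
print(round(capped, 2))  # 36.04, i.e. 52 * log(2)

# A correct and fully confident prediction costs nothing.
print(neg_log_proba(1.0, 1) == 0.0)  # True
```

Since many posterior means in the summaries are exactly 0.00 or 1.00, any such entry that disagrees with the self-assessment contributes this capped value, which can dominate the per-skill averages.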


Log pointwise predictive density from log_likelihood

signature: log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)

In [6]:
def log_ppd(
    model: Callable,
    posterior_samples: Dict,
    *args,
    parallel=False,
    batch_ndims=1,
    **kwargs
):
    """
    Log pointwise predictive density
    :param model Callable: Python callable containing Pyro primitives
    :param posterior_samples Dict: dictionary of samples from the posterior.
    :param args: model arguments
    :param parallel bool: passed to log_likelihood from numpyro.infer
    :param batch_ndims Union[0, 1, 2]: passed to log_likelihood from numpyro.infer, see log_likelihood for details
    :param kwargs: model kwargs
    """
    post_loglik = log_likelihood(
        model,
        posterior_samples,
        *args,
        parallel=parallel,
        batch_ndims=batch_ndims,
        **kwargs
    )
    # Stack the per-site log likelihoods into shape (n_samples, n_sites).
    post_loglik_res = np.concatenate(
        [obs[:, None] for obs in post_loglik.values()], axis=1
    )
    # log of the mean likelihood over posterior samples, computed stably.
    exp_log_density = logsumexp(post_loglik_res, axis=0) - jnp.log(
        jnp.shape(post_loglik_res)[0]
    )
    return exp_log_density
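
The last step of `log_ppd` computes, for each observed site, $\log \frac{1}{S} \sum_{s=1}^{S} \exp(\ell_s)$, where $\ell_s$ is the log likelihood under posterior sample $s$ and $S$ is the number of samples. The sketch below illustrates why this is done with `logsumexp` rather than by exponentiating directly. It uses only the standard library; `log_mean_exp` and the example values are assumptions made for illustration.

```python
import math
from typing import List


def log_mean_exp(log_values: List[float]) -> float:
    """Return log(mean(exp(v))) stably by factoring out the maximum value."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


# Hypothetical per-sample log likelihoods for a single observation.
log_liks = [-1000.0, -1001.0]

# Exponentiating directly underflows to zero, so a naive log would fail.
print(math.exp(-1000.0))  # 0.0

# The stabilised version stays finite.
print(log_mean_exp(log_liks))

# For moderate values it agrees with the direct formula.
print(log_mean_exp([math.log(0.2), math.log(0.4)]), math.log(0.3))
```

In `log_ppd` the same reduction is applied along `axis=0`, the sample axis, so the result has one entry per `isCorrect_*` site.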

In [7]:
rng_key = jax.random.PRNGKey(2)


Get Data

In [8]:
raw_data = pd.read_csv(
)
self_assessed = raw_data.iloc[1:, 1:8].copy()
self_assessed = self_assessed.astype(int)

)
skills_needed = []
for index, row in skills_key.iterrows():
skills_needed.append([i for i, x in enumerate(row) if x])

)

responses = responses.astype("int32")


Without plates

Define models and run inference

In [9]:
def model_00(
):
    n_skills = max(map(max, skills_needed)) + 1

    participants_plate = numpyro.plate("participants_plate", n_participants)

    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))

    for q in range(n_questions):
        # A question is answered from knowledge only if every required skill is present.
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
        )
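
The core of `model_00` is how the probability of a correct answer is built from the skill indicators: the product of the required skills acts as a logical AND, and the result selects between `1 - prob_mistake` and `prob_guess`. The following standalone sketch shows this for one participant with plain Python integers. The function `prob_correct_for` and the example skill vectors are hypothetical; the default values 0.1 and 0.2 are the ones passed to `log_density` later in this notebook.

```python
import operator
from functools import reduce
from typing import List


def prob_correct_for(
    skills: List[int],
    needed: List[int],
    prob_mistake: float = 0.1,
    prob_guess: float = 0.2,
) -> float:
    """Probability of a correct answer for one participant and one question."""
    # Product of 0/1 indicators is 1 only if all required skills are present.
    has_skills = reduce(operator.mul, [skills[i] for i in needed])
    return has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess


participant = [1, 1, 0]  # has skills 0 and 1, lacks skill 2

print(prob_correct_for(participant, needed=[0, 1]))  # 0.9
print(prob_correct_for(participant, needed=[0, 2]))  # 0.2
```

In the model itself the same expression is evaluated on arrays with one entry per participant, so each `isCorrect_{q}` site receives a vector of probabilities.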

In [10]:
nuts_kernel = NUTS(model_00)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc_00 = MCMC(
kernel, num_warmup=200, num_samples=1000, num_chains=4, jit_model_args=False
)
mcmc_00.run(
rng_key,
jnp.array(responses),
skills_needed,
extra_fields=(
"z",
"hmc_state.potential_energy",
"hmc_state.z",
"rng_key",
"hmc_state.rng_key",
),
)
mcmc_00.print_summary()

/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using numpyro.set_host_device_count(4) at the beginning of your program. You can double-check how many devices are available in your system using jax.local_device_count().
warnings.warn(
sample: 100%|██████████| 1200/1200 [02:02<00:00,  9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|██████████| 1200/1200 [02:03<00:00,  9.70it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|██████████| 1200/1200 [02:03<00:00,  9.73it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|██████████| 1200/1200 [02:02<00:00,  9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
skill_0[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[1]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[2]      0.01      0.10      0.00      0.00      0.00   4093.80      1.00
skill_0[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_0[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_0[8]      0.99      0.12      1.00      1.00      1.00   4130.52      1.00
skill_0[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[10]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_0[11]      0.97      0.18      1.00      1.00      1.00   4311.42      1.00
skill_0[12]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[13]      0.66      0.47      1.00      0.00      1.00  12611.01      1.00
skill_0[14]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_0[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[1]      1.00      0.07      1.00      1.00      1.00   4043.37      1.00
skill_1[2]      0.00      0.07      0.00      0.00      0.00   3397.09      1.00
skill_1[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[5]      0.00      0.06      0.00      0.00      0.00   4039.58      1.00
skill_1[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[7]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[11]      0.94      0.24      1.00      1.00      1.00   3595.70      1.00
skill_1[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[13]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[14]      0.97      0.17      1.00      1.00      1.00   3986.16      1.00
skill_1[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_2[0]      0.98      0.15      1.00      1.00      1.00   4195.02      1.00
skill_2[1]      0.59      0.49      1.00      0.00      1.00  27054.31      1.00
skill_2[2]      0.04      0.20      0.00      0.00      0.00   4370.98      1.00
skill_2[3]      0.58      0.49      1.00      0.00      1.00  27689.66      1.00
skill_2[4]      0.98      0.13      1.00      1.00      1.00   4140.87      1.00
skill_2[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_2[6]      0.98      0.14      1.00      1.00      1.00   3888.96      1.00
skill_2[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_2[8]      0.98      0.13      1.00      1.00      1.00   4056.76      1.00
skill_2[9]      0.04      0.20      0.00      0.00      0.00   4393.83      1.00
skill_2[10]      0.58      0.49      1.00      0.00      1.00  29403.65      1.00
skill_2[11]      0.98      0.14      1.00      1.00      1.00   3875.56      1.00
skill_2[12]      0.98      0.14      1.00      1.00      1.00   4166.72      1.00
skill_2[13]      0.98      0.15      1.00      1.00      1.00   4175.88      1.00
skill_2[14]      0.58      0.49      1.00      0.00      1.00  37696.08      1.00
skill_2[15]      0.59      0.49      1.00      0.00      1.00  24106.45      1.00
skill_2[16]      0.58      0.49      1.00      0.00      1.00  31530.27      1.00
skill_2[17]      0.98      0.14      1.00      1.00      1.00   4186.04      1.00
skill_2[18]      0.98      0.14      1.00      1.00      1.00   4171.18      1.00
skill_2[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_2[20]      0.98      0.13      1.00      1.00      1.00   4110.53      1.00
skill_2[21]      0.58      0.49      1.00      0.00      1.00  24811.19      1.00
skill_3[0]      0.99      0.08      1.00      1.00      1.00   3895.34      1.00
skill_3[1]      0.99      0.08      1.00      1.00      1.00   4060.72      1.00
skill_3[2]      0.99      0.09      1.00      1.00      1.00   3623.58      1.00
skill_3[3]      0.99      0.09      1.00      1.00      1.00   3235.95      1.00
skill_3[4]      0.78      0.41      1.00      0.00      1.00   6003.05      1.00
skill_3[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_3[6]      0.99      0.10      1.00      1.00      1.00   4033.33      1.00
skill_3[7]      0.78      0.42      1.00      0.00      1.00   7574.41      1.00
skill_3[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[11]      0.00      0.04      0.00      0.00      0.00   4022.87      1.00
skill_3[12]      0.99      0.09      1.00      1.00      1.00   3973.96      1.00
skill_3[13]      0.99      0.09      1.00      1.00      1.00   3833.68      1.00
skill_3[14]      0.99      0.09      1.00      1.00      1.00   4075.71      1.00
skill_3[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[16]      0.99      0.08      1.00      1.00      1.00   4055.73      1.00
skill_3[17]      0.99      0.09      1.00      1.00      1.00   4076.82      1.00
skill_3[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[19]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_4[0]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_4[1]      0.15      0.35      0.00      0.00      1.00   5860.67      1.00
skill_4[2]      0.59      0.49      1.00      0.00      1.00  21820.18      1.00
skill_4[3]      0.15      0.36      0.00      0.00      1.00   5949.44      1.00
skill_4[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_4[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
skill_4[6]      1.00      0.06      1.00      1.00      1.00   4036.51      1.00
skill_4[7]      0.15      0.36      0.00      0.00      1.00   5670.10      1.00
skill_4[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_4[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_4[10]      0.86      0.35      1.00      0.00      1.00   5786.31      1.00
skill_4[11]      1.00      0.07      1.00      1.00      1.00   3728.74      1.00
skill_4[12]      0.87      0.34      1.00      0.00      1.00   4455.87      1.00
skill_4[13]      0.15      0.36      0.00      0.00      1.00   5889.51      1.00
skill_4[14]      0.86      0.34      1.00      0.00      1.00   5095.11      1.00
skill_4[15]      0.86      0.35      1.00      0.00      1.00   5561.46      1.00
skill_4[16]      1.00      0.07      1.00      1.00      1.00   4041.21      1.00
skill_4[17]      1.00      0.06      1.00      1.00      1.00   4035.86      1.00
skill_4[18]      0.86      0.34      1.00      0.00      1.00   5361.01      1.00
skill_4[19]      0.99      0.08      1.00      1.00      1.00   4053.58      1.00
skill_4[20]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_4[21]      1.00      0.06      1.00      1.00      1.00   4035.14      1.00
skill_5[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[1]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_5[2]      0.78      0.42      1.00      0.00      1.00   7042.74      1.00
skill_5[3]      0.99      0.09      1.00      1.00      1.00   4074.01      1.00
skill_5[4]      0.99      0.08      1.00      1.00      1.00   3818.77      1.00
skill_5[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_5[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[8]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[9]      0.99      0.08      1.00      1.00      1.00   3627.43      1.00
skill_5[10]      0.78      0.41      1.00      0.00      1.00   7519.46      1.00
skill_5[11]      0.00      0.02      0.00      0.00      0.00       nan      1.00
skill_5[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[13]      0.78      0.42      1.00      0.00      1.00   7063.90      1.00
skill_5[14]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[16]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[17]      0.78      0.41      1.00      0.00      1.00   7747.84      1.00
skill_5[18]      0.78      0.42      1.00      0.00      1.00   6926.57      1.00
skill_5[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[1]      0.99      0.10      1.00      1.00      1.00   4020.57      1.00
skill_6[2]      0.50      0.50      0.00      0.00      1.00  -7660.98      1.00
skill_6[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[5]      0.50      0.50      0.00      0.00      1.00  -4056.58      1.00
skill_6[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[11]      0.09      0.29      0.00      0.00      0.00   2600.06      1.00
skill_6[12]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[13]      0.99      0.09      1.00      1.00      1.00   4078.84      1.00
skill_6[14]      0.99      0.11      1.00      1.00      1.00   4109.17      1.00
skill_6[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[18]      0.99      0.12      1.00      1.00      1.00   4128.28      1.00
skill_6[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00


In [11]:
ds = az.from_numpyro(mcmc_00)

In [12]:
az.plot_trace(ds);

In [13]:
log_density_model_00, model_00_trace = log_density(
model_00,
(jnp.array(responses), skills_needed),
dict(prob_mistake=0.1, prob_guess=0.2),
{key: value.mean(0) for key, value in mcmc_00.get_samples().items()},
)

In [14]:
pe_model_00 = mcmc_00.get_extra_fields()["hmc_state.potential_energy"]

In [15]:
exp_log_density_00 = log_ppd(
model_00, mcmc_00.get_samples(), jnp.array(responses), skills_needed
)

In [16]:
# post_loglik_00 = log_likelihood(
#     model_00, mcmc_00.get_samples(), jnp.array(responses), skills_needed,
# )
# post_loglik_00_res = np.concatenate(
#     [obs[:, None] for obs in post_loglik_00.values()], axis=1
# )
# exp_log_density_00 = logsumexp(post_loglik_00_res, axis=0) - jnp.log(
#     jnp.shape(post_loglik_00_res)[0]
# )

In [17]:
theta_model_00 = np.zeros((22, 7))
for i, param in enumerate(["skill_" + str(i) for i in range(7)]):
    theta_model_00[:, i] = np.mean(mcmc_00.get_samples()[param], axis=0)

neg_log_proba_model_00 = neg_log_proba_score(theta_model_00, self_assessed.values)

In [18]:
def model_02(
):
    n_skills = max(map(max, skills_needed)) + 1

    # One learned guess probability per question.
    with numpyro.plate("questions_plate", n_questions):
        prob_guess = numpyro.sample("prob_guess", dist.Beta(2.5, 7.5))

    participants_plate = numpyro.plate("participants_plate", n_participants)

    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))

    for q in range(n_questions):
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = (
            has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess[q]
        )
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
        )
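
`model_02` differs from `model_00` in that each question gets its own `prob_guess`, drawn from a `Beta(2.5, 7.5)` prior. As a rough check on what that prior implies, the sketch below evaluates the standard closed-form mean and standard deviation of a Beta distribution; `beta_mean_std` is a hypothetical helper introduced only for this illustration.

```python
import math
from typing import Tuple


def beta_mean_std(a: float, b: float) -> Tuple[float, float]:
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)


mean, std = beta_mean_std(2.5, 7.5)
print(round(mean, 3), round(std, 3))  # 0.25 0.131
```

The prior is therefore centred near the fixed value of 0.2 used for `model_00`, while remaining wide enough for the data to move individual questions away from it.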

In [19]:
nuts_kernel = NUTS(model_02)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc_02 = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=4)
mcmc_02.run(
rng_key,
jnp.array(responses),
skills_needed,
extra_fields=(
"z",
"hmc_state.potential_energy",
"hmc_state.z",
"rng_key",
"hmc_state.rng_key",
),
)
mcmc_02.print_summary()

/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using numpyro.set_host_device_count(4) at the beginning of your program. You can double-check how many devices are available in your system using jax.local_device_count().
warnings.warn(
sample: 100%|██████████| 1200/1200 [02:19<00:00,  8.60it/s, 7 steps of size 4.34e-01. acc. prob=0.88]
sample: 100%|██████████| 1200/1200 [02:19<00:00,  8.59it/s, 7 steps of size 4.57e-01. acc. prob=0.86]
sample: 100%|██████████| 1200/1200 [02:21<00:00,  8.47it/s, 7 steps of size 4.90e-01. acc. prob=0.86]
sample: 100%|██████████| 1200/1200 [02:23<00:00,  8.35it/s, 7 steps of size 4.47e-01. acc. prob=0.86]

                    mean       std    median      5.0%     95.0%     n_eff     r_hat
prob_guess[0]      0.26      0.12      0.25      0.05      0.43   3138.97      1.00
prob_guess[1]      0.28      0.12      0.27      0.08      0.46   3689.75      1.00
prob_guess[2]      0.32      0.13      0.31      0.10      0.51   5050.20      1.00
prob_guess[3]      0.33      0.13      0.32      0.11      0.53   3007.40      1.00
prob_guess[4]      0.29      0.12      0.28      0.09      0.48   2769.02      1.00
prob_guess[5]      0.20      0.11      0.19      0.04      0.38   5336.04      1.00
prob_guess[6]      0.39      0.13      0.39      0.18      0.60   4493.49      1.00
prob_guess[7]      0.38      0.13      0.37      0.16      0.59   5009.81      1.00
prob_guess[8]      0.39      0.13      0.38      0.17      0.59   3560.31      1.00
prob_guess[9]      0.32      0.13      0.31      0.10      0.52   3133.14      1.00
prob_guess[10]      0.30      0.13      0.29      0.09      0.49   3694.41      1.00
prob_guess[11]      0.19      0.11      0.18      0.03      0.35   4524.02      1.00
prob_guess[12]      0.23      0.12      0.22      0.04      0.41   3143.92      1.00
prob_guess[13]      0.20      0.11      0.19      0.03      0.36   5236.74      1.00
prob_guess[14]      0.22      0.11      0.21      0.04      0.39   3761.43      1.00
prob_guess[15]      0.35      0.12      0.34      0.13      0.54   3402.84      1.00
prob_guess[16]      0.31      0.12      0.30      0.10      0.49   4800.58      1.00
prob_guess[17]      0.24      0.12      0.23      0.05      0.43   4647.01      1.00
prob_guess[18]      0.61      0.11      0.61      0.44      0.78   3861.77      1.00
prob_guess[19]      0.56      0.11      0.56      0.38      0.73   4028.99      1.00
prob_guess[20]      0.51      0.11      0.51      0.33      0.70   3595.62      1.00
prob_guess[21]      0.11      0.07      0.10      0.01      0.21   4919.36      1.00
prob_guess[22]      0.44      0.11      0.44      0.27      0.63   3543.74      1.00
prob_guess[23]      0.28      0.13      0.27      0.07      0.47   3130.95      1.00
prob_guess[24]      0.26      0.12      0.25      0.07      0.46   2904.73      1.00
prob_guess[25]      0.34      0.13      0.33      0.11      0.52   4288.83      1.00
prob_guess[26]      0.52      0.15      0.53      0.29      0.76    958.60      1.00
prob_guess[27]      0.52      0.14      0.54      0.29      0.76   1093.10      1.00
prob_guess[28]      0.52      0.15      0.53      0.30      0.77   1114.17      1.00
prob_guess[29]      0.20      0.10      0.18      0.04      0.34   1932.30      1.00
prob_guess[30]      0.23      0.11      0.22      0.05      0.40   3297.66      1.00
prob_guess[31]      0.31      0.13      0.31      0.07      0.51   1391.98      1.00
prob_guess[32]      0.45      0.15      0.46      0.21      0.70   1048.96      1.01
prob_guess[33]      0.47      0.15      0.48      0.21      0.71   1007.86      1.00
prob_guess[34]      0.33      0.11      0.32      0.15      0.52   3329.14      1.00
prob_guess[35]      0.41      0.12      0.41      0.23      0.61   3670.40      1.00
prob_guess[36]      0.38      0.12      0.38      0.19      0.57   4165.09      1.00
prob_guess[37]      0.53      0.12      0.53      0.33      0.73   4196.24      1.00
prob_guess[38]      0.46      0.12      0.46      0.26      0.66   3673.99      1.00
prob_guess[39]      0.23      0.10      0.22      0.07      0.40   3564.23      1.00
prob_guess[40]      0.43      0.13      0.43      0.22      0.65   3616.51      1.00
prob_guess[41]      0.43      0.13      0.43      0.23      0.64   3123.85      1.00
prob_guess[42]      0.37      0.13      0.37      0.16      0.58   2582.66      1.00
prob_guess[43]      0.30      0.12      0.29      0.10      0.48   4110.90      1.00
prob_guess[44]      0.30      0.11      0.29      0.12      0.48   4875.23      1.00
prob_guess[45]      0.31      0.11      0.30      0.12      0.49   3179.36      1.00
prob_guess[46]      0.19      0.10      0.17      0.03      0.33   4516.82      1.00
prob_guess[47]      0.31      0.12      0.30      0.13      0.51   5223.94      1.00
skill_0[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[1]      1.00      0.03      1.00      1.00      1.00       nan      1.00
skill_0[2]      0.01      0.10      0.00      0.00      0.00   3876.81      1.00
skill_0[3]      1.00      0.04      1.00      1.00      1.00       nan      1.00
skill_0[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_0[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[7]      1.00      0.07      1.00      1.00      1.00   4041.91      1.00
skill_0[8]      0.92      0.27      1.00      1.00      1.00   3931.25      1.00
skill_0[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[10]      0.99      0.11      1.00      1.00      1.00   4110.15      1.00
skill_0[11]      0.89      0.31      1.00      0.00      1.00   4984.72      1.00
skill_0[12]      1.00      0.03      1.00      1.00      1.00       nan      1.00
skill_0[13]      0.28      0.45      0.00      0.00      1.00   3007.28      1.00
skill_0[14]      0.96      0.19      1.00      1.00      1.00   3995.70      1.00
skill_0[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[17]      1.00      0.03      1.00      1.00      1.00       nan      1.00
skill_0[18]      0.99      0.09      1.00      1.00      1.00   3971.27      1.00
skill_0[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[1]      0.98      0.15      1.00      1.00      1.00   3746.90      1.00
skill_1[2]      0.00      0.07      0.00      0.00      0.00   4040.30      1.00
skill_1[3]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[5]      0.00      0.06      0.00      0.00      0.00   4033.86      1.00
skill_1[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[7]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[11]      0.81      0.39      1.00      0.00      1.00   1393.43      1.00
skill_1[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[13]      1.00      0.05      1.00      1.00      1.00   4029.22      1.00
skill_1[14]      0.71      0.45      1.00      0.00      1.00   3865.86      1.00
skill_1[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[18]      1.00      0.04      1.00      1.00      1.00       nan      1.00
skill_1[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[21]      1.00      0.07      1.00      1.00      1.00   4045.54      1.00
skill_2[0]      0.51      0.50      1.00      0.00      1.00  13562.32      1.00
skill_2[1]      0.08      0.28      0.00      0.00      0.00   4902.13      1.00
skill_2[2]      0.01      0.11      0.00      0.00      0.00   3871.70      1.00
skill_2[3]      0.12      0.32      0.00      0.00      1.00   5291.34      1.00
skill_2[4]      0.51      0.50      1.00      0.00      1.00  12431.32      1.00
skill_2[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
skill_2[6]      0.51      0.50      1.00      0.00      1.00  20899.36      1.00
skill_2[7]      0.99      0.12      1.00      1.00      1.00   4125.13      1.00
skill_2[8]      0.52      0.50      1.00      0.00      1.00  14466.44      1.00
skill_2[9]      0.02      0.14      0.00      0.00      0.00   4122.87      1.00
skill_2[10]      0.09      0.29      0.00      0.00      0.00   4396.90      1.00
skill_2[11]      0.51      0.50      1.00      0.00      1.00  18257.86      1.00
skill_2[12]      0.51      0.50      1.00      0.00      1.00  13388.88      1.00
skill_2[13]      0.89      0.32      1.00      0.00      1.00   5149.99      1.00
skill_2[14]      0.14      0.34      0.00      0.00      1.00   5471.75      1.00
skill_2[15]      0.09      0.28      0.00      0.00      0.00   4937.10      1.00
skill_2[16]      0.18      0.38      0.00      0.00      1.00   5614.03      1.00
skill_2[17]      0.51      0.50      1.00      0.00      1.00  23114.37      1.00
skill_2[18]      0.51      0.50      1.00      0.00      1.00  20113.88      1.00
skill_2[19]      0.99      0.12      1.00      1.00      1.00   4116.94      1.00
skill_2[20]      0.51      0.50      1.00      0.00      1.00  15794.23      1.00
skill_2[21]      0.10      0.30      0.00      0.00      0.00   4538.81      1.00
skill_3[0]      0.52      0.50      1.00      0.00      1.00   1627.99      1.00
skill_3[1]      0.62      0.49      1.00      0.00      1.00   2340.25      1.00
skill_3[2]      0.61      0.49      1.00      0.00      1.00   2249.26      1.00
skill_3[3]      0.73      0.44      1.00      0.00      1.00   2846.19      1.00
skill_3[4]      0.11      0.31      0.00      0.00      1.00   1697.16      1.00
skill_3[5]      0.00      0.02      0.00      0.00      0.00       nan      1.00
skill_3[6]      0.52      0.50      1.00      0.00      1.00   1757.86      1.00
skill_3[7]      0.11      0.31      0.00      0.00      1.00   1598.81      1.00
skill_3[8]      0.94      0.24      1.00      1.00      1.00   3779.98      1.00
skill_3[9]      0.94      0.23      1.00      1.00      1.00   3426.03      1.00
skill_3[10]      0.94      0.24      1.00      1.00      1.00   3526.44      1.00
skill_3[11]      0.00      0.03      0.00      0.00      0.00       nan      1.00
skill_3[12]      0.52      0.50      1.00      0.00      1.00   1658.32      1.00
skill_3[13]      0.61      0.49      1.00      0.00      1.00   2149.84      1.00
skill_3[14]      0.52      0.50      1.00      0.00      1.00   1790.86      1.00
skill_3[15]      1.00      0.03      1.00      1.00      1.00       nan      1.00
skill_3[16]      0.52      0.50      1.00      0.00      1.00   1769.12      1.00
skill_3[17]      0.52      0.50      1.00      0.00      1.00   1683.65      1.00
skill_3[18]      0.94      0.23      1.00      1.00      1.00   3907.40      1.00
skill_3[19]      0.94      0.23      1.00      1.00      1.00   4002.44      1.00
skill_3[20]      0.95      0.23      1.00      1.00      1.00   3661.50      1.00
skill_3[21]      0.94      0.23      1.00      1.00      1.00   3765.80      1.00
skill_4[0]      0.99      0.08      1.00      1.00      1.00   4060.42      1.00
skill_4[1]      0.09      0.29      0.00      0.00      0.00   4379.23      1.00
skill_4[2]      0.49      0.50      0.00      0.00      1.00  10748.09      1.00
skill_4[3]      0.04      0.20      0.00      0.00      0.00   4252.62      1.00
skill_4[4]      0.99      0.07      1.00      1.00      1.00   4049.81      1.00
skill_4[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
skill_4[6]      0.91      0.29      1.00      1.00      1.00   4622.34      1.00
skill_4[7]      0.07      0.26      0.00      0.00      0.00   4727.83      1.00
skill_4[8]      0.99      0.07      1.00      1.00      1.00   3776.96      1.00
skill_4[9]      0.99      0.09      1.00      1.00      1.00   4067.26      1.00
skill_4[10]      0.39      0.49      0.00      0.00      1.00   8301.22      1.00
skill_4[11]      0.89      0.31      1.00      0.00      1.00   4541.33      1.00
skill_4[12]      0.38      0.49      0.00      0.00      1.00   7387.72      1.00
skill_4[13]      0.06      0.23      0.00      0.00      0.00   4535.34      1.00
skill_4[14]      0.27      0.44      0.00      0.00      1.00   5097.39      1.00
skill_4[15]      0.31      0.46      0.00      0.00      1.00   7812.16      1.00
skill_4[16]      0.87      0.34      1.00      0.00      1.00   5593.14      1.00
skill_4[17]      0.93      0.26      1.00      1.00      1.00   4655.86      1.00
skill_4[18]      0.31      0.46      0.00      0.00      1.00   7551.88      1.00
skill_4[19]      0.91      0.29      1.00      1.00      1.00   4885.18      1.00
skill_4[20]      0.99      0.08      1.00      1.00      1.00   4057.65      1.00
skill_4[21]      0.91      0.28      1.00      1.00      1.00   4967.61      1.00
skill_5[0]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[1]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_5[2]      0.29      0.45      0.00      0.00      1.00   4111.53      1.00
skill_5[3]      0.83      0.38      1.00      0.00      1.00   5250.81      1.00
skill_5[4]      0.90      0.30      1.00      1.00      1.00   4413.22      1.00
skill_5[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
skill_5[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[7]      0.99      0.07      1.00      1.00      1.00   4051.40      1.00
skill_5[8]      1.00      0.05      1.00      1.00      1.00   4026.94      1.00
skill_5[9]      0.83      0.37      1.00      0.00      1.00   5908.88      1.00
skill_5[10]      0.43      0.50      0.00      0.00      1.00   6477.28      1.00
skill_5[11]      0.00      0.02      0.00      0.00      0.00       nan      1.00
skill_5[12]      0.99      0.11      1.00      1.00      1.00   4104.54      1.00
skill_5[13]      0.30      0.46      0.00      0.00      1.00   4174.26      1.00
skill_5[14]      0.99      0.10      1.00      1.00      1.00   4099.45      1.00
skill_5[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[16]      0.99      0.10      1.00      1.00      1.00   4096.47      1.00
skill_5[17]      0.28      0.45      0.00      0.00      1.00   4807.04      1.00
skill_5[18]      0.29      0.45      0.00      0.00      1.00   4008.14      1.00
skill_5[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[1]      0.85      0.36      1.00      0.00      1.00   4567.60      1.00
skill_6[2]      0.50      0.50      0.00      0.00      1.00  -9642.91      1.00
skill_6[3]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[5]      0.50      0.50      0.00      0.00      1.00  -4032.19      1.00
skill_6[6]      1.00      0.04      1.00      1.00      1.00   4023.54      1.00
skill_6[7]      0.99      0.08      1.00      1.00      1.00   3523.17      1.00
skill_6[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[11]      0.11      0.31      0.00      0.00      1.00   1995.29      1.00
skill_6[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[13]      0.84      0.36      1.00      0.00      1.00   4683.67      1.00
skill_6[14]      0.82      0.39      1.00      0.00      1.00   3364.55      1.00
skill_6[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[18]      0.87      0.33      1.00      0.00      1.00   4571.07      1.00
skill_6[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[20]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_6[21]      0.99      0.12      1.00      1.00      1.00   3642.66      1.00
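
Each `mean` in the summary above is the fraction of posterior draws in which the binary skill variable equals 1, so it can be read as the inferred probability $p(skill_i)$ for that participant. The sketch below shows how such probabilities could be turned into the negative log probability of the ground truth described in the Purpose section. The function name `neg_log_prob_of_truth`, the clipping constant `eps`, and the ground-truth values in the example are assumptions made for illustration and are not taken from the notebook. The natural logarithm is used here; the quoted value $log(0.8)=-0.097$ corresponds to base 10, so `np.log10` would be needed to match it.

```python
import numpy as np


def neg_log_prob_of_truth(p_skill, truth, eps=1e-6):
    """Negative log probability of the ground truth, per participant.

    p_skill : inferred probability that each participant has the skill
    truth   : 0/1 ground-truth indicator for each participant
    eps     : clipping constant (an assumption of this sketch) that keeps the
              result finite when a posterior mean is exactly 0 or 1
    """
    p = np.clip(np.asarray(p_skill, dtype=float), eps, 1.0 - eps)
    t = np.asarray(truth, dtype=float)
    # log(p) where the person has the skill, log(1 - p) where they do not
    return -(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))


# Posterior means copied from the rows skill_2[0], skill_2[1], skill_2[7] above.
# In the notebook they could presumably be obtained with something like
#   np.asarray(mcmc_02.get_samples()["skill_2"]).mean(axis=0)
# assuming the sample array has shape (draws, participants).
p_example = np.array([0.51, 0.08, 0.99])

# Hypothetical ground truth, for illustration only (not the real data).
truth_example = np.array([1, 0, 1])

nlp = neg_log_prob_of_truth(p_example, truth_example)
print(nlp)        # one value per participant
print(nlp.sum())  # total for the skill, as in the per-skill bars of Figure 2.29
```

Without clipping, a row such as `skill_5[1]` (mean 0.00) would give an infinite penalty if the participant actually had the skill, so the choice of `eps` directly affects the reported metric.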


In [20]:
# Convert the fitted NumPyro MCMC object into an ArviZ InferenceData object
ds = az.from_numpyro(mcmc_02)

In [21]:
# Trace plots for every sampled variable; the trailing semicolon suppresses the returned axes repr
az.plot_trace(ds);