Purpose


Figure 2.29: (left) Overall negative log probability for the original model and the model with learned guess probabilities. The lower red bar indicates that learning the guess probabilities gives a substantially better model, according to this metric. (right) Negative log probability for each skill, showing that the improvement varies from skill to skill.

  • The figure does not show the log_density of the model; it shows the negative log probability of the ground truth. For a participant and a given $skill_i$, the negative log probability is
$$-\log(p(skill_i = truth_i))$$

where $truth_i$ is an indicator variable for having $skill_i$, and each skill is modelled as $skill_i \sim Bernoulli(\theta_i)$, so that $p(skill_i = 1) = \theta_i$.

Further details from the text:

A common metric to use is the probability of the ground truth values under the inferred distributions. Sometimes it is convenient to take the logarithm of the probability, since this gives a more manageable number when the probability is very small. When we use the logarithm of the probability, the metric is referred to as the log probability. So, if the inferred probability of a person having a particular skill is $p$, then the log probability is $log(p)$ if the person has the skill and $log(1−p)$ if they don’t. If the person does have the skill then the best possible prediction is $p=1.0$, which gives log probability of $log(1.0)=0$ (the logarithm of one is zero). A less confident prediction, such as $p=0.8$ will give a log probability with a negative value, in this case $log(0.8)=−0.097$. The worst possible prediction of $p=0.0$ gives a log probability of negative infinity. ...
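As a quick check on the numbers in the quoted passage, the sketch below evaluates the log probability for a few predictions. Note that the quoted value $log(0.8)=−0.097$ corresponds to a base-10 logarithm, whereas `np.log` (used by `neg_log_proba_score` in this notebook) is the natural logarithm, which gives roughly −0.223 for the same prediction. The helper name `log_prob` is made up for this illustration.

```python
import math


def log_prob(p: float, has_skill: bool, base: float = math.e) -> float:
    """Log probability of the ground truth under an inferred probability p."""
    prob_of_truth = p if has_skill else 1.0 - p
    if prob_of_truth == 0.0:
        return -math.inf  # worst possible prediction
    return math.log(prob_of_truth) / math.log(base)


print(log_prob(1.0, True))           # perfect prediction
print(log_prob(0.8, True, base=10))  # base-10 value quoted in the text
print(log_prob(0.8, True))           # natural log, as np.log would give
print(log_prob(0.8, False))          # person lacks the skill: log(1 - p)
print(log_prob(0.0, True))           # worst case
```

The ranking of predictions is the same in either base; only the scale of the metric changes.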

In [1]:
import operator
from functools import reduce
from typing import Callable, Dict, List

import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from jax.scipy.special import logsumexp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, log_likelihood
from numpyro.infer.util import log_density, potential_energy
In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext watermark
In [3]:
%watermark -v -m -p arviz,jax,matplotlib,numpy,pandas,scipy,numpyro
Python implementation: CPython
Python version       : 3.8.11
IPython version      : 7.18.1

arviz     : 0.11.2
jax       : 0.2.19
matplotlib: 3.4.3
numpy     : 1.20.3
pandas    : 1.3.2
scipy     : 1.6.2
numpyro   : 0.7.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 4.19.193-1-MANJARO
Machine     : x86_64
Processor   : 
CPU cores   : 4
Architecture: 64bit

In [4]:
%watermark -gb
Git hash: 307321cc497d1542d2908d60950823b102b16219

Git branch: master

In [5]:
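def neg_log_proba_score(theta: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Calculates the negative log probability of the ground truth (the self-assessed skills).
    :param theta np.ndarray: array of Bernoulli probabilities, one per participant and skill
    :param y_true np.ndarray, dtype == int: array of indicator variables for skill of participants
    :return np.ndarray: element-wise negative log probabilities, same shape as the inputs
    """
    assert theta.shape == y_true.shape
    assert np.issubdtype(y_true.dtype, np.integer)
    # Probability assigned to the observed value: theta if y_true == 1, else 1 - theta
    score = scipy.stats.bernoulli(theta).pmf(y_true)
    # Replace exact zeros with machine epsilon so the log stays finite
    score[score == 0.0] = np.finfo(float).eps
    return -np.log(score)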


def plot_bars(
    data: np.ndarray,
    columns: List[str],
    index: List[str],
    ax=None,
    tick_step=0.5,
    **kwargs,
):
    """
    Draws a grouped bar chart of `data` (rows labelled by `index`, one bar per entry in `columns`).
    """
    if ax is None:
        _, ax = plt.subplots()

    pd.DataFrame(data, columns=columns, index=index).plot(
        kind="bar", color=["b", "r"], ax=ax, zorder=3, **kwargs
    )
    ax.grid(zorder=0, axis="y")
    ax.yaxis.set_ticks(np.arange(0, data.max(), tick_step))
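
To see what the clipping step in `neg_log_proba_score` does, here is a dependency-free sketch of the same scoring rule for a single prediction. It assumes natural logarithms and uses `sys.float_info.epsilon`, which is the same machine epsilon as `np.finfo(float).eps`; the function name `neg_log_proba_single` is hypothetical.

```python
import math
import sys


def neg_log_proba_single(theta: float, y_true: int) -> float:
    """Negative log probability of one ground-truth label (0 or 1)."""
    # Bernoulli pmf: theta if the label is 1, otherwise 1 - theta
    prob = theta if y_true == 1 else 1.0 - theta
    # Mirror the notebook: an exact zero is replaced by machine epsilon
    if prob == 0.0:
        prob = sys.float_info.epsilon
    return -math.log(prob)


print(neg_log_proba_single(0.8, 1))  # confident and right: small penalty
print(neg_log_proba_single(0.8, 0))  # confident and wrong: larger penalty
print(neg_log_proba_single(1.0, 0))  # certain and wrong: capped near 36 rather than infinite
```

A posterior mean of exactly 0 or 1 that disagrees with the self-assessment therefore contributes about 36 to the score instead of infinity.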

Log pointwise predictive density from log_likelihood

signature: log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)

In [6]:
def log_ppd(
    model: Callable,
    posterior_samples: Dict,
    *args,
    parallel=False,
    batch_ndims=1,
    **kwargs
):
    """
    Log pointwise predictive density
    :param model Callable: Python callable containing Pyro primitives
    :param posterior_samples Dict: dictionary of samples from the posterior.
    :param args: model arguments
    :param parallel bool: passed to `log_likelihood` from numpyro.infer
    :param batch_ndims Union[0, 1, 2]: passed to `log_likelihood` from numpyro.infer, see `log_likelihood` for details
    :param kwargs: model kwargs
    """
    post_loglik = log_likelihood(
        model,
        posterior_samples,
        *args,
        parallel=parallel,
        batch_ndims=batch_ndims,
        **kwargs
    )
    post_loglik_res = np.concatenate(
        [obs[:, None] for obs in post_loglik.values()], axis=1
    )
    exp_log_density = logsumexp(post_loglik_res, axis=0) - jnp.log(
        jnp.shape(post_loglik_res)[0]
    )
    return exp_log_density
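
The last step of `log_ppd` is a log-mean-exp over posterior draws: the log of the average likelihood, computed without leaving log space. A minimal standard-library sketch of that step, with a hypothetical helper `log_mean_exp` and made-up numbers:

```python
import math
from typing import List


def log_mean_exp(log_values: List[float]) -> float:
    """log(mean(exp(log_values))), computed stably by factoring out the maximum."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


# Two posterior draws assigning likelihood 0.2 and 0.4 to one observation
print(log_mean_exp([math.log(0.2), math.log(0.4)]))  # equals log(0.3) up to rounding

# Very negative log-likelihoods would underflow if exponentiated directly
print(log_mean_exp([-1000.0, -1000.0]))
```

Subtracting the maximum before exponentiating is the same trick `logsumexp` uses internally, which is why `log_ppd` can work with very small likelihoods.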
In [7]:
rng_key = jax.random.PRNGKey(2)

Get Data

In [8]:
raw_data = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-RawResponsesAsDictionary.csv"
)
self_assessed = raw_data.iloc[1:, 1:8].copy()
self_assessed = self_assessed.astype(int)

skills_key = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-Quiz-SkillsQuestionsMask.csv",
    header=None,
)
skills_needed = []
for index, row in skills_key.iterrows():
    skills_needed.append([i for i, x in enumerate(row) if x])

responses = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-IsCorrect.csv",
    header=None,
)

responses = responses.astype("int32")
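
The loop that builds `skills_needed` turns each row of the question/skill mask into a list of skill indices. A toy illustration with an invented three-skill mask (the real mask comes from the CSV above):

```python
from typing import List

# Hypothetical mask: rows are questions, columns are skills
toy_mask = [
    [True, False, False],  # question 0 needs skill 0
    [True, False, True],   # question 1 needs skills 0 and 2
    [False, True, False],  # question 2 needs skill 1
]

toy_skills_needed: List[List[int]] = [
    [i for i, needed in enumerate(row) if needed] for row in toy_mask
]
print(toy_skills_needed)  # [[0], [0, 2], [1]]
```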

Without plates

Define models and run inference

In [9]:
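def model_00(
    graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2
):
    """
    Skills model with fixed mistake and guess probabilities.
    :param graded_responses: array of shape (n_questions, n_participants), 1 where the answer was correct
    :param skills_needed List[List[int]]: for each question, the indices of the skills it requires
    :param prob_mistake float: probability of a wrong answer despite having all required skills
    :param prob_guess float: probability of a correct answer without all required skills
    """
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1

    participants_plate = numpyro.plate("participants_plate", n_participants)

    # One Bernoulli(0.5) prior per skill, vectorised over participants
    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))

    for q in range(n_questions):
        # 1 only if the participant has every skill needed for question q
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )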
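
The per-question likelihood in `model_00` is a noisy AND: a participant who has every required skill answers correctly with probability `1 - prob_mistake`, and anyone else with probability `prob_guess`. A plain-Python sketch for one participant, using the model's default values and a hypothetical helper name:

```python
from typing import List


def prob_correct(
    skills: List[int],
    needed: List[int],
    prob_mistake: float = 0.1,
    prob_guess: float = 0.2,
) -> float:
    """Probability of a correct answer given 0/1 skill indicators."""
    has_skills = 1
    for i in needed:
        has_skills *= skills[i]  # product of indicators acts as a logical AND
    return has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess


participant = [1, 1, 0]  # has skills 0 and 1, lacks skill 2
print(prob_correct(participant, [0, 1]))  # all required skills present
print(prob_correct(participant, [0, 2]))  # one required skill missing
```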
In [10]:
nuts_kernel = NUTS(model_00)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc_00 = MCMC(
    kernel, num_warmup=200, num_samples=1000, num_chains=4, jit_model_args=False
)
mcmc_00.run(
    rng_key,
    jnp.array(responses),
    skills_needed,
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_00.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  warnings.warn(
sample: 100%|██████████| 1200/1200 [02:02<00:00,  9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00] 
sample: 100%|██████████| 1200/1200 [02:03<00:00,  9.70it/s, 1 steps of size 1.19e+37. acc. prob=1.00] 
sample: 100%|██████████| 1200/1200 [02:03<00:00,  9.73it/s, 1 steps of size 1.19e+37. acc. prob=1.00] 
sample: 100%|██████████| 1200/1200 [02:02<00:00,  9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00] 
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
 skill_0[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_0[1]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_0[2]      0.01      0.10      0.00      0.00      0.00   4093.80      1.00
 skill_0[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_0[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_0[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
 skill_0[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_0[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_0[8]      0.99      0.12      1.00      1.00      1.00   4130.52      1.00
 skill_0[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[10]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_0[11]      0.97      0.18      1.00      1.00      1.00   4311.42      1.00
skill_0[12]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[13]      0.66      0.47      1.00      0.00      1.00  12611.01      1.00
skill_0[14]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_0[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_0[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[1]      1.00      0.07      1.00      1.00      1.00   4043.37      1.00
 skill_1[2]      0.00      0.07      0.00      0.00      0.00   3397.09      1.00
 skill_1[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[5]      0.00      0.06      0.00      0.00      0.00   4039.58      1.00
 skill_1[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[7]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_1[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[11]      0.94      0.24      1.00      1.00      1.00   3595.70      1.00
skill_1[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[13]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_1[14]      0.97      0.17      1.00      1.00      1.00   3986.16      1.00
skill_1[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_1[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_2[0]      0.98      0.15      1.00      1.00      1.00   4195.02      1.00
 skill_2[1]      0.59      0.49      1.00      0.00      1.00  27054.31      1.00
 skill_2[2]      0.04      0.20      0.00      0.00      0.00   4370.98      1.00
 skill_2[3]      0.58      0.49      1.00      0.00      1.00  27689.66      1.00
 skill_2[4]      0.98      0.13      1.00      1.00      1.00   4140.87      1.00
 skill_2[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
 skill_2[6]      0.98      0.14      1.00      1.00      1.00   3888.96      1.00
 skill_2[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_2[8]      0.98      0.13      1.00      1.00      1.00   4056.76      1.00
 skill_2[9]      0.04      0.20      0.00      0.00      0.00   4393.83      1.00
skill_2[10]      0.58      0.49      1.00      0.00      1.00  29403.65      1.00
skill_2[11]      0.98      0.14      1.00      1.00      1.00   3875.56      1.00
skill_2[12]      0.98      0.14      1.00      1.00      1.00   4166.72      1.00
skill_2[13]      0.98      0.15      1.00      1.00      1.00   4175.88      1.00
skill_2[14]      0.58      0.49      1.00      0.00      1.00  37696.08      1.00
skill_2[15]      0.59      0.49      1.00      0.00      1.00  24106.45      1.00
skill_2[16]      0.58      0.49      1.00      0.00      1.00  31530.27      1.00
skill_2[17]      0.98      0.14      1.00      1.00      1.00   4186.04      1.00
skill_2[18]      0.98      0.14      1.00      1.00      1.00   4171.18      1.00
skill_2[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_2[20]      0.98      0.13      1.00      1.00      1.00   4110.53      1.00
skill_2[21]      0.58      0.49      1.00      0.00      1.00  24811.19      1.00
 skill_3[0]      0.99      0.08      1.00      1.00      1.00   3895.34      1.00
 skill_3[1]      0.99      0.08      1.00      1.00      1.00   4060.72      1.00
 skill_3[2]      0.99      0.09      1.00      1.00      1.00   3623.58      1.00
 skill_3[3]      0.99      0.09      1.00      1.00      1.00   3235.95      1.00
 skill_3[4]      0.78      0.41      1.00      0.00      1.00   6003.05      1.00
 skill_3[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
 skill_3[6]      0.99      0.10      1.00      1.00      1.00   4033.33      1.00
 skill_3[7]      0.78      0.42      1.00      0.00      1.00   7574.41      1.00
 skill_3[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_3[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[11]      0.00      0.04      0.00      0.00      0.00   4022.87      1.00
skill_3[12]      0.99      0.09      1.00      1.00      1.00   3973.96      1.00
skill_3[13]      0.99      0.09      1.00      1.00      1.00   3833.68      1.00
skill_3[14]      0.99      0.09      1.00      1.00      1.00   4075.71      1.00
skill_3[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[16]      0.99      0.08      1.00      1.00      1.00   4055.73      1.00
skill_3[17]      0.99      0.09      1.00      1.00      1.00   4076.82      1.00
skill_3[18]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[19]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_3[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_3[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_4[0]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_4[1]      0.15      0.35      0.00      0.00      1.00   5860.67      1.00
 skill_4[2]      0.59      0.49      1.00      0.00      1.00  21820.18      1.00
 skill_4[3]      0.15      0.36      0.00      0.00      1.00   5949.44      1.00
 skill_4[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_4[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
 skill_4[6]      1.00      0.06      1.00      1.00      1.00   4036.51      1.00
 skill_4[7]      0.15      0.36      0.00      0.00      1.00   5670.10      1.00
 skill_4[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_4[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_4[10]      0.86      0.35      1.00      0.00      1.00   5786.31      1.00
skill_4[11]      1.00      0.07      1.00      1.00      1.00   3728.74      1.00
skill_4[12]      0.87      0.34      1.00      0.00      1.00   4455.87      1.00
skill_4[13]      0.15      0.36      0.00      0.00      1.00   5889.51      1.00
skill_4[14]      0.86      0.34      1.00      0.00      1.00   5095.11      1.00
skill_4[15]      0.86      0.35      1.00      0.00      1.00   5561.46      1.00
skill_4[16]      1.00      0.07      1.00      1.00      1.00   4041.21      1.00
skill_4[17]      1.00      0.06      1.00      1.00      1.00   4035.86      1.00
skill_4[18]      0.86      0.34      1.00      0.00      1.00   5361.01      1.00
skill_4[19]      0.99      0.08      1.00      1.00      1.00   4053.58      1.00
skill_4[20]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_4[21]      1.00      0.06      1.00      1.00      1.00   4035.14      1.00
 skill_5[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_5[1]      0.00      0.00      0.00      0.00      0.00       nan       nan
 skill_5[2]      0.78      0.42      1.00      0.00      1.00   7042.74      1.00
 skill_5[3]      0.99      0.09      1.00      1.00      1.00   4074.01      1.00
 skill_5[4]      0.99      0.08      1.00      1.00      1.00   3818.77      1.00
 skill_5[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
 skill_5[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_5[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_5[8]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_5[9]      0.99      0.08      1.00      1.00      1.00   3627.43      1.00
skill_5[10]      0.78      0.41      1.00      0.00      1.00   7519.46      1.00
skill_5[11]      0.00      0.02      0.00      0.00      0.00       nan      1.00
skill_5[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[13]      0.78      0.42      1.00      0.00      1.00   7063.90      1.00
skill_5[14]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[16]      1.00      0.02      1.00      1.00      1.00       nan      1.00
skill_5[17]      0.78      0.41      1.00      0.00      1.00   7747.84      1.00
skill_5[18]      0.78      0.42      1.00      0.00      1.00   6926.57      1.00
skill_5[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_5[21]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[1]      0.99      0.10      1.00      1.00      1.00   4020.57      1.00
 skill_6[2]      0.50      0.50      0.00      0.00      1.00  -7660.98      1.00
 skill_6[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[5]      0.50      0.50      0.00      0.00      1.00  -4056.58      1.00
 skill_6[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[7]      1.00      0.02      1.00      1.00      1.00       nan      1.00
 skill_6[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
 skill_6[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[11]      0.09      0.29      0.00      0.00      0.00   2600.06      1.00
skill_6[12]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[13]      0.99      0.09      1.00      1.00      1.00   4078.84      1.00
skill_6[14]      0.99      0.11      1.00      1.00      1.00   4109.17      1.00
skill_6[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[18]      0.99      0.12      1.00      1.00      1.00   4128.28      1.00
skill_6[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
skill_6[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00

In [11]:
ds = az.from_numpyro(mcmc_00)
In [12]:
az.plot_trace(ds);
In [13]:
log_density_model_00, model_00_trace = log_density(
    model_00,
    (jnp.array(responses), skills_needed),
    dict(prob_mistake=0.1, prob_guess=0.2),
    {key: value.mean(0) for key, value in mcmc_00.get_samples().items()},
)
In [14]:
pe_model_00 = mcmc_00.get_extra_fields()["hmc_state.potential_energy"]
In [15]:
exp_log_density_00 = log_ppd(
    model_00, mcmc_00.get_samples(), jnp.array(responses), skills_needed
)
In [16]:
# post_loglik_00 = log_likelihood(
#     model_00, mcmc_00.get_samples(), jnp.array(responses), skills_needed,
# )
# post_loglik_00_res = np.concatenate(
#     [obs[:, None] for obs in post_loglik_00.values()], axis=1
# )
# exp_log_density_00 = logsumexp(post_loglik_00_res, axis=0) - jnp.log(
#     jnp.shape(post_loglik_00_res)[0]
# )
In [17]:
samples_00 = mcmc_00.get_samples()
theta_model_00 = np.zeros((22, 7))  # (participants, skills)
for s in range(7):
    # Posterior mean of each skill indicator, per participant
    theta_model_00[:, s] = np.mean(samples_00["skill_" + str(s)], axis=0)

neg_log_proba_model_00 = neg_log_proba_score(theta_model_00, self_assessed.values)
In [18]:
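def model_02(
    graded_responses, skills_needed: List[List[int]], prob_mistake=0.1,
):
    """
    Skills model with a learned guess probability for each question.
    :param graded_responses: array of shape (n_questions, n_participants), 1 where the answer was correct
    :param skills_needed List[List[int]]: for each question, the indices of the skills it requires
    :param prob_mistake float: probability of a wrong answer despite having all required skills
    """
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1

    # One guess probability per question, with a Beta(2.5, 7.5) prior
    with numpyro.plate("questions_plate", n_questions):
        prob_guess = numpyro.sample("prob_guess", dist.Beta(2.5, 7.5))

    participants_plate = numpyro.plate("participants_plate", n_participants)

    # One Bernoulli(0.5) prior per skill, vectorised over participants
    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))

    for q in range(n_questions):
        # 1 only if the participant has every skill needed for question q
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = (
            has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess[q]
        )
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )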
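
`model_02` replaces the fixed guess probability with a `Beta(2.5, 7.5)` prior per question. For orientation, the mean and standard deviation of that prior follow from the standard Beta formulas; the function name below is made up for this sketch:

```python
import math
from typing import Tuple


def beta_mean_std(a: float, b: float) -> Tuple[float, float]:
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)


mean, std = beta_mean_std(2.5, 7.5)
print(round(mean, 3), round(std, 3))  # 0.25 0.131
```

The prior mean of 0.25 sits close to the fixed `prob_guess=0.2` used in `model_00`, so the two models start from similar assumptions about guessing.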
In [19]:
nuts_kernel = NUTS(model_02)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

mcmc_02 = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=4)
mcmc_02.run(
    rng_key,
    jnp.array(responses),
    skills_needed,
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_02.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  warnings.warn(
sample: 100%|██████████| 1200/1200 [02:19<00:00,  8.60it/s, 7 steps of size 4.34e-01. acc. prob=0.88]  
sample: 100%|██████████| 1200/1200 [02:19<00:00,  8.59it/s, 7 steps of size 4.57e-01. acc. prob=0.86]  
sample: 100%|██████████| 1200/1200 [02:21<00:00,  8.47it/s, 7 steps of size 4.90e-01. acc. prob=0.86]  
sample: 100%|██████████| 1200/1200 [02:23<00:00,  8.35it/s, 7 steps of size 4.47e-01. acc. prob=0.86]  
                    mean       std    median      5.0%     95.0%     n_eff     r_hat
 prob_guess[0]      0.26      0.12      0.25      0.05      0.43   3138.97      1.00
 prob_guess[1]      0.28      0.12      0.27      0.08      0.46   3689.75      1.00
 prob_guess[2]      0.32      0.13      0.31      0.10      0.51   5050.20      1.00
 prob_guess[3]      0.33      0.13      0.32      0.11      0.53   3007.40      1.00
 prob_guess[4]      0.29      0.12      0.28      0.09      0.48   2769.02      1.00
 prob_guess[5]      0.20      0.11      0.19      0.04      0.38   5336.04      1.00
 prob_guess[6]      0.39      0.13      0.39      0.18      0.60   4493.49      1.00
 prob_guess[7]      0.38      0.13      0.37      0.16      0.59   5009.81      1.00
 prob_guess[8]      0.39      0.13      0.38      0.17      0.59   3560.31      1.00
 prob_guess[9]      0.32      0.13      0.31      0.10      0.52   3133.14      1.00
prob_guess[10]      0.30      0.13      0.29      0.09      0.49   3694.41      1.00
prob_guess[11]      0.19      0.11      0.18      0.03      0.35   4524.02      1.00
prob_guess[12]      0.23      0.12      0.22      0.04      0.41   3143.92      1.00
prob_guess[13]      0.20      0.11      0.19      0.03      0.36   5236.74      1.00
prob_guess[14]      0.22      0.11      0.21      0.04      0.39   3761.43      1.00
prob_guess[15]      0.35      0.12      0.34      0.13      0.54   3402.84      1.00
prob_guess[16]      0.31      0.12      0.30      0.10      0.49   4800.58      1.00
prob_guess[17]      0.24      0.12      0.23      0.05      0.43   4647.01      1.00
prob_guess[18]      0.61      0.11      0.61      0.44      0.78   3861.77      1.00
prob_guess[19]      0.56      0.11      0.56      0.38      0.73   4028.99      1.00
prob_guess[20]      0.51      0.11      0.51      0.33      0.70   3595.62      1.00
prob_guess[21]      0.11      0.07      0.10      0.01      0.21   4919.36      1.00
prob_guess[22]      0.44      0.11      0.44      0.27      0.63   3543.74      1.00
prob_guess[23]      0.28      0.13      0.27      0.07      0.47   3130.95      1.00
prob_guess[24]      0.26      0.12      0.25      0.07      0.46   2904.73      1.00
prob_guess[25]      0.34      0.13      0.33      0.11      0.52   4288.83      1.00
prob_guess[26]      0.52      0.15      0.53      0.29      0.76    958.60      1.00
prob_guess[27]      0.52      0.14      0.54      0.29      0.76   1093.10      1.00
prob_guess[28]      0.52      0.15      0.53      0.30      0.77   1114.17      1.00
prob_guess[29]      0.20      0.10      0.18      0.04      0.34   1932.30      1.00
prob_guess[30]      0.23      0.11      0.22      0.05      0.40   3297.66      1.00
prob_guess[31]      0.31      0.13      0.31      0.07      0.51   1391.98      1.00
prob_guess[32]      0.45      0.15      0.46      0.21      0.70   1048.96      1.01
prob_guess[33]      0.47      0.15      0.48      0.21      0.71   1007.86      1.00
prob_guess[34]      0.33      0.11      0.32      0.15      0.52   3329.14      1.00
prob_guess[35]      0.41      0.12      0.41      0.23      0.61   3670.40      1.00
prob_guess[36]      0.38      0.12      0.38      0.19      0.57   4165.09      1.00
prob_guess[37]      0.53      0.12      0.53      0.33      0.73   4196.24      1.00
prob_guess[38]      0.46      0.12      0.46      0.26      0.66   3673.99      1.00
prob_guess[39]      0.23      0.10      0.22      0.07      0.40   3564.23      1.00
prob_guess[40]      0.43      0.13      0.43      0.22      0.65   3616.51      1.00
prob_guess[41]      0.43      0.13      0.43      0.23      0.64   3123.85      1.00
prob_guess[42]      0.37      0.13      0.37      0.16      0.58   2582.66      1.00
prob_guess[43]      0.30      0.12      0.29      0.10      0.48   4110.90      1.00
prob_guess[44]      0.30      0.11      0.29      0.12      0.48   4875.23      1.00
prob_guess[45]      0.31      0.11      0.30      0.12      0.49   3179.36      1.00
prob_guess[46]      0.19      0.10      0.17      0.03      0.33   4516.82      1.00
prob_guess[47]      0.31      0.12      0.30      0.13      0.51   5223.94      1.00
    skill_0[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_0[1]      1.00      0.03      1.00      1.00      1.00       nan      1.00
    skill_0[2]      0.01      0.10      0.00      0.00      0.00   3876.81      1.00
    skill_0[3]      1.00      0.04      1.00      1.00      1.00       nan      1.00
    skill_0[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_0[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
    skill_0[6]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_0[7]      1.00      0.07      1.00      1.00      1.00   4041.91      1.00
    skill_0[8]      0.92      0.27      1.00      1.00      1.00   3931.25      1.00
    skill_0[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_0[10]      0.99      0.11      1.00      1.00      1.00   4110.15      1.00
   skill_0[11]      0.89      0.31      1.00      0.00      1.00   4984.72      1.00
   skill_0[12]      1.00      0.03      1.00      1.00      1.00       nan      1.00
   skill_0[13]      0.28      0.45      0.00      0.00      1.00   3007.28      1.00
   skill_0[14]      0.96      0.19      1.00      1.00      1.00   3995.70      1.00
   skill_0[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_0[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_0[17]      1.00      0.03      1.00      1.00      1.00       nan      1.00
   skill_0[18]      0.99      0.09      1.00      1.00      1.00   3971.27      1.00
   skill_0[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_0[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_0[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_1[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_1[1]      0.98      0.15      1.00      1.00      1.00   3746.90      1.00
    skill_1[2]      0.00      0.07      0.00      0.00      0.00   4040.30      1.00
    skill_1[3]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_1[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_1[5]      0.00      0.06      0.00      0.00      0.00   4033.86      1.00
    skill_1[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_1[7]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_1[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_1[9]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[11]      0.81      0.39      1.00      0.00      1.00   1393.43      1.00
   skill_1[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_1[13]      1.00      0.05      1.00      1.00      1.00   4029.22      1.00
   skill_1[14]      0.71      0.45      1.00      0.00      1.00   3865.86      1.00
   skill_1[15]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[18]      1.00      0.04      1.00      1.00      1.00       nan      1.00
   skill_1[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_1[21]      1.00      0.07      1.00      1.00      1.00   4045.54      1.00
    skill_2[0]      0.51      0.50      1.00      0.00      1.00  13562.32      1.00
    skill_2[1]      0.08      0.28      0.00      0.00      0.00   4902.13      1.00
    skill_2[2]      0.01      0.11      0.00      0.00      0.00   3871.70      1.00
    skill_2[3]      0.12      0.32      0.00      0.00      1.00   5291.34      1.00
    skill_2[4]      0.51      0.50      1.00      0.00      1.00  12431.32      1.00
    skill_2[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
    skill_2[6]      0.51      0.50      1.00      0.00      1.00  20899.36      1.00
    skill_2[7]      0.99      0.12      1.00      1.00      1.00   4125.13      1.00
    skill_2[8]      0.52      0.50      1.00      0.00      1.00  14466.44      1.00
    skill_2[9]      0.02      0.14      0.00      0.00      0.00   4122.87      1.00
   skill_2[10]      0.09      0.29      0.00      0.00      0.00   4396.90      1.00
   skill_2[11]      0.51      0.50      1.00      0.00      1.00  18257.86      1.00
   skill_2[12]      0.51      0.50      1.00      0.00      1.00  13388.88      1.00
   skill_2[13]      0.89      0.32      1.00      0.00      1.00   5149.99      1.00
   skill_2[14]      0.14      0.34      0.00      0.00      1.00   5471.75      1.00
   skill_2[15]      0.09      0.28      0.00      0.00      0.00   4937.10      1.00
   skill_2[16]      0.18      0.38      0.00      0.00      1.00   5614.03      1.00
   skill_2[17]      0.51      0.50      1.00      0.00      1.00  23114.37      1.00
   skill_2[18]      0.51      0.50      1.00      0.00      1.00  20113.88      1.00
   skill_2[19]      0.99      0.12      1.00      1.00      1.00   4116.94      1.00
   skill_2[20]      0.51      0.50      1.00      0.00      1.00  15794.23      1.00
   skill_2[21]      0.10      0.30      0.00      0.00      0.00   4538.81      1.00
    skill_3[0]      0.52      0.50      1.00      0.00      1.00   1627.99      1.00
    skill_3[1]      0.62      0.49      1.00      0.00      1.00   2340.25      1.00
    skill_3[2]      0.61      0.49      1.00      0.00      1.00   2249.26      1.00
    skill_3[3]      0.73      0.44      1.00      0.00      1.00   2846.19      1.00
    skill_3[4]      0.11      0.31      0.00      0.00      1.00   1697.16      1.00
    skill_3[5]      0.00      0.02      0.00      0.00      0.00       nan      1.00
    skill_3[6]      0.52      0.50      1.00      0.00      1.00   1757.86      1.00
    skill_3[7]      0.11      0.31      0.00      0.00      1.00   1598.81      1.00
    skill_3[8]      0.94      0.24      1.00      1.00      1.00   3779.98      1.00
    skill_3[9]      0.94      0.23      1.00      1.00      1.00   3426.03      1.00
   skill_3[10]      0.94      0.24      1.00      1.00      1.00   3526.44      1.00
   skill_3[11]      0.00      0.03      0.00      0.00      0.00       nan      1.00
   skill_3[12]      0.52      0.50      1.00      0.00      1.00   1658.32      1.00
   skill_3[13]      0.61      0.49      1.00      0.00      1.00   2149.84      1.00
   skill_3[14]      0.52      0.50      1.00      0.00      1.00   1790.86      1.00
   skill_3[15]      1.00      0.03      1.00      1.00      1.00       nan      1.00
   skill_3[16]      0.52      0.50      1.00      0.00      1.00   1769.12      1.00
   skill_3[17]      0.52      0.50      1.00      0.00      1.00   1683.65      1.00
   skill_3[18]      0.94      0.23      1.00      1.00      1.00   3907.40      1.00
   skill_3[19]      0.94      0.23      1.00      1.00      1.00   4002.44      1.00
   skill_3[20]      0.95      0.23      1.00      1.00      1.00   3661.50      1.00
   skill_3[21]      0.94      0.23      1.00      1.00      1.00   3765.80      1.00
    skill_4[0]      0.99      0.08      1.00      1.00      1.00   4060.42      1.00
    skill_4[1]      0.09      0.29      0.00      0.00      0.00   4379.23      1.00
    skill_4[2]      0.49      0.50      0.00      0.00      1.00  10748.09      1.00
    skill_4[3]      0.04      0.20      0.00      0.00      0.00   4252.62      1.00
    skill_4[4]      0.99      0.07      1.00      1.00      1.00   4049.81      1.00
    skill_4[5]      0.00      0.03      0.00      0.00      0.00       nan      1.00
    skill_4[6]      0.91      0.29      1.00      1.00      1.00   4622.34      1.00
    skill_4[7]      0.07      0.26      0.00      0.00      0.00   4727.83      1.00
    skill_4[8]      0.99      0.07      1.00      1.00      1.00   3776.96      1.00
    skill_4[9]      0.99      0.09      1.00      1.00      1.00   4067.26      1.00
   skill_4[10]      0.39      0.49      0.00      0.00      1.00   8301.22      1.00
   skill_4[11]      0.89      0.31      1.00      0.00      1.00   4541.33      1.00
   skill_4[12]      0.38      0.49      0.00      0.00      1.00   7387.72      1.00
   skill_4[13]      0.06      0.23      0.00      0.00      0.00   4535.34      1.00
   skill_4[14]      0.27      0.44      0.00      0.00      1.00   5097.39      1.00
   skill_4[15]      0.31      0.46      0.00      0.00      1.00   7812.16      1.00
   skill_4[16]      0.87      0.34      1.00      0.00      1.00   5593.14      1.00
   skill_4[17]      0.93      0.26      1.00      1.00      1.00   4655.86      1.00
   skill_4[18]      0.31      0.46      0.00      0.00      1.00   7551.88      1.00
   skill_4[19]      0.91      0.29      1.00      1.00      1.00   4885.18      1.00
   skill_4[20]      0.99      0.08      1.00      1.00      1.00   4057.65      1.00
   skill_4[21]      0.91      0.28      1.00      1.00      1.00   4967.61      1.00
    skill_5[0]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_5[1]      0.00      0.00      0.00      0.00      0.00       nan       nan
    skill_5[2]      0.29      0.45      0.00      0.00      1.00   4111.53      1.00
    skill_5[3]      0.83      0.38      1.00      0.00      1.00   5250.81      1.00
    skill_5[4]      0.90      0.30      1.00      1.00      1.00   4413.22      1.00
    skill_5[5]      0.00      0.00      0.00      0.00      0.00       nan       nan
    skill_5[6]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_5[7]      0.99      0.07      1.00      1.00      1.00   4051.40      1.00
    skill_5[8]      1.00      0.05      1.00      1.00      1.00   4026.94      1.00
    skill_5[9]      0.83      0.37      1.00      0.00      1.00   5908.88      1.00
   skill_5[10]      0.43      0.50      0.00      0.00      1.00   6477.28      1.00
   skill_5[11]      0.00      0.02      0.00      0.00      0.00       nan      1.00
   skill_5[12]      0.99      0.11      1.00      1.00      1.00   4104.54      1.00
   skill_5[13]      0.30      0.46      0.00      0.00      1.00   4174.26      1.00
   skill_5[14]      0.99      0.10      1.00      1.00      1.00   4099.45      1.00
   skill_5[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_5[16]      0.99      0.10      1.00      1.00      1.00   4096.47      1.00
   skill_5[17]      0.28      0.45      0.00      0.00      1.00   4807.04      1.00
   skill_5[18]      0.29      0.45      0.00      0.00      1.00   4008.14      1.00
   skill_5[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_5[20]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_5[21]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_6[0]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_6[1]      0.85      0.36      1.00      0.00      1.00   4567.60      1.00
    skill_6[2]      0.50      0.50      0.00      0.00      1.00  -9642.91      1.00
    skill_6[3]      1.00      0.02      1.00      1.00      1.00       nan      1.00
    skill_6[4]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_6[5]      0.50      0.50      0.00      0.00      1.00  -4032.19      1.00
    skill_6[6]      1.00      0.04      1.00      1.00      1.00   4023.54      1.00
    skill_6[7]      0.99      0.08      1.00      1.00      1.00   3523.17      1.00
    skill_6[8]      1.00      0.00      1.00      1.00      1.00       nan       nan
    skill_6[9]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_6[10]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_6[11]      0.11      0.31      0.00      0.00      1.00   1995.29      1.00
   skill_6[12]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_6[13]      0.84      0.36      1.00      0.00      1.00   4683.67      1.00
   skill_6[14]      0.82      0.39      1.00      0.00      1.00   3364.55      1.00
   skill_6[15]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_6[16]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_6[17]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_6[18]      0.87      0.33      1.00      0.00      1.00   4571.07      1.00
   skill_6[19]      1.00      0.00      1.00      1.00      1.00       nan       nan
   skill_6[20]      1.00      0.02      1.00      1.00      1.00       nan      1.00
   skill_6[21]      0.99      0.12      1.00      1.00      1.00   3642.66      1.00
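
Because each `skill_k[i]` site is binary, the `mean` column above is the fraction of posterior draws equal to 1, i.e. the inferred probability that participant `i` has skill `k`. The sketch below shows one way that probability could be turned into the negative log probability of the ground truth. It is a minimal illustration, not the notebook's own implementation: the function name `neg_log_prob_of_truth`, the toy `samples` and `truth` arrays, and the `eps` clipping are all assumptions introduced here. Base-10 logarithms are used so that the result matches the worked value quoted earlier ($log(0.8) = -0.097$); the base is likewise an assumption.

```python
import numpy as np


def neg_log_prob_of_truth(samples, truth, eps=1e-6):
    """Negative log10 probability of the ground truth under posterior draws.

    samples: array of 0/1 draws with shape (num_draws, num_people).
    truth:   boolean array with shape (num_people,), True if the person has the skill.
    eps:     hypothetical clipping constant that keeps log10 finite when every
             draw agrees (posterior mean exactly 0 or 1).
    """
    # The mean over draws of a 0/1 variable is the inferred probability p.
    p = np.clip(np.asarray(samples, dtype=float).mean(axis=0), eps, 1.0 - eps)
    # log(p) if the person has the skill, log(1 - p) otherwise.
    return -np.where(truth, np.log10(p), np.log10(1.0 - p))


# Toy data (hypothetical): 10 draws for 2 people.
samples = np.array([[1, 0]] * 8 + [[0, 1]] * 2)  # posterior means 0.8 and 0.2
truth = np.array([True, False])
scores = neg_log_prob_of_truth(samples, truth)
print(scores)
```

Without clipping, rows such as those with mean 1.00 and std 0.00 would give an infinite score whenever the ground truth disagrees with the draws, so some guard of this kind is needed when the probabilities are estimated from a finite number of samples.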

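In [20]:
# Convert the fitted NumPyro MCMC object into an ArviZ InferenceData object.
ds = az.from_numpyro(mcmc_02)

In [21]:
# Trace plots for every sampled site; the trailing semicolon suppresses the returned axes array.
az.plot_trace(ds);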