import pymc3 as pm
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%qtconsole --colors=linux
plt.style.use('ggplot')
from matplotlib import gridspec
from theano import tensor as tt
from scipy import stats
The basic data for an SDT analysis are the counts of hits, false alarms, misses, and correct rejections. It is common to consider just the hit and false alarm counts which, together with the total number of signal and noise trials, completely describe the data.
| | Signal trial | Noise trial |
|---|---|---|
| Yes response | Hit | False alarm |
| No response | Miss | Correct rejection |
from matplotlib.patches import Polygon

# Sketch of the equal-variance Gaussian SDT framework: the noise density is
# N(0, 1), the signal density is N(d, 1), and the observer says "yes" when the
# evidence exceeds the criterion k. The areas to the right of k under the
# signal and noise curves are the hit rate theta^h and false-alarm rate theta^f.
func = stats.norm.pdf
# a: criterion k (tick-labelled '$k$'), b: right end of the plotted/shaded
# range, d: discriminability (mean of the signal distribution).
a, b, d = 2, 7, 2.25
x = np.linspace(-4, b, 100)
y = func(x)          # noise density
y2 = func(x, loc=d)  # signal density
fig, ax = plt.subplots(figsize=(15, 6))
plt.plot(x, y, 'k', linewidth=2)
plt.plot(x, y2, 'k', linewidth=2)

# Shade the "yes" region between k and b: hits under the signal curve
# (light grey), false alarms under the noise curve (darker grey).
ix = np.linspace(a, b)
iy = func(ix)
iy2 = func(ix, loc=d)
verts = [(a, 0)] + list(zip(ix, iy2)) + [(b, 0)]
poly = Polygon(verts, facecolor='.8', edgecolor='1')
ax.add_patch(poly)
verts = [(a, 0)] + list(zip(ix, iy)) + [(b, 0)]
poly = Polygon(verts, facecolor='.5', edgecolor='0.5')
ax.add_patch(poly)

plt.text(0, y.max()+.02, "noise",
         horizontalalignment='center', fontsize=20)
plt.text(d, y.max()+.02, "signal",
         horizontalalignment='center', fontsize=20)
# The bias c is the offset of the criterion k from the unbiased point d/2.
plt.text((d/2+a)/2, y.max()*.08, r"$ - \,c - $",
         horizontalalignment='center', fontsize=20)
plt.text(b*.7, y.max()*.5, r"$\theta^h$",
         horizontalalignment='center', fontsize=20)
plt.text(b*.7, y.max()*.2, r"$\theta^f$",
         horizontalalignment='center', fontsize=20)
# Leader lines from the shaded areas to the theta^h / theta^f labels.
plt.plot([d+1, b*.7-.25], [y.max()*.4, y.max()*.5+.01],
         color='k', linestyle='-', linewidth=1.5)
plt.plot([d+.05, b*.7-.25], [y.max()*.01, y.max()*.2+.01],
         color='k', linestyle='-', linewidth=1.5)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')

# Dashed reference lines at 0, d/2, k and d, up to the relevant curve.
plt.plot([0, 0], [0, func(0)], color='k', linestyle='--', linewidth=2)
plt.plot([d/2, d/2], [0, func(d/2)], color='k', linestyle='--', linewidth=2)
plt.plot([a, a], [0, func(a, loc=d)], color='k', linestyle='--', linewidth=2)
plt.plot([d, d], [0, func(d, loc=d)], color='k', linestyle='--', linewidth=2)
ax.set_xticks((0, a, d, d/2))
ax.set_xticklabels(('$0$', '$k$', '$d$', '$d/2$'), fontsize=18)
ax.set_yticks([])
plt.ylim(0, y.max()+.1)
plt.title('Equal-variance Gaussian signal detection theory framework', fontsize=20)
plt.xlabel('Strength', fontsize=20)
plt.show()
# Load data: one row per case, columns are
# (hits, false alarms, misses, correct rejections).
dataset = 1
k = 3  # number of cases (same for both datasets)
if dataset != 1:
    # Lehrner et al. (1995) data
    counts = [148, 29, 32, 151,
              150, 40, 30, 140,
              150, 51, 40, 139]
else:
    # Demo
    counts = [70, 50, 30, 50,
              7, 5, 3, 5,
              10, 0, 0, 10]
data = np.array(counts).reshape(k, -1)
h, f, MI, CR = data.T
s = h + MI  # number of signal trials per case
n = f + CR  # number of noise trials per case
def Phi(x):
    """Standard normal cumulative distribution function.

    This is the inverse of the probit function, so it maps a real-valued
    latent quantity to a probability. It is built from ``pm.math`` operations
    so it can be applied to PyMC3 random variables:
    Phi(x) = 1/2 + erf(x / sqrt(2)) / 2.
    """
    z = x / pm.math.sqrt(2)
    return 0.5 + 0.5 * pm.math.erf(z)
with pm.Model() as model1:
    # Independent priors per case: precision 0.5 on discriminability,
    # precision 2 on bias.
    di = pm.Normal('Discriminability', mu=0, tau=.5, shape=k)
    ci = pm.Normal('Bias', mu=0, tau=2, shape=k)
    # Equal-variance SDT link: hit rate Phi(d/2 - c), false-alarm rate Phi(-d/2 - c).
    thetah = pm.Deterministic('Hit Rate', Phi(di/2 - ci))
    thetaf = pm.Deterministic('False Alarm Rate', Phi(-di/2 - ci))
    hi = pm.Binomial('hi', p=thetah, n=s, observed=h)
    fi = pm.Binomial('fi', p=thetaf, n=n, observed=f)
    # The number of draws must be an integer; 1e4 is a float literal.
    trace1 = pm.sample(10000, njobs=2)

burnin = 0
pm.traceplot(trace1[burnin:],
             varnames=['Discriminability', 'Bias', 'Hit Rate', 'False Alarm Rate']);
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 100%|██████████| 10500/10500.0 [00:10<00:00, 977.34it/s]
# Load data using rpy2
# NOTE(review): nothing from this wildcard import appears to be used below;
# the %load_ext line is what enables the %R magic.
from rpy2 import *
# Register rpy2's IPython extension, which provides the %R line magic.
%load_ext rpy2.ipython
# Run the R data file and export the R objects std_i (induction) and std_d
# (deduction) into the Python namespace via -o.
# NOTE(review): R's source() evaluates R code text; if this .RData file is a
# binary save() image it would need load() instead — confirm the file format.
%R source("data/heit_rotello.RData") -o std_i -o std_d
# Split the induction (std_i) and deduction (std_d) data into SDT counts.
# Columns V1..V4 hold hits, false alarms, misses and correct rejections.
_COLUMNS = ('V1', 'V2', 'V3', 'V4')
h1, f1, MI1, CR1 = (np.array(std_i[col]) for col in _COLUMNS)
s1, n1 = h1 + MI1, f1 + CR1  # signal / noise trial totals (induction)
h2, f2, MI2, CR2 = (np.array(std_d[col]) for col in _COLUMNS)
s2, n2 = h2 + MI2, f2 + CR2  # signal / noise trial totals (deduction)
k = len(h1)
with pm.Model() as model2i:
    # Vague hyperpriors: population means with precision .001 and
    # Gamma(.001, .001) priors on the population precisions.
    mud = pm.Normal('mud', mu=0, tau=.001)
    muc = pm.Normal('muc', mu=0, tau=.001)
    lambdad = pm.Gamma('lambdad', alpha=.001, beta=.001)
    lambdac = pm.Gamma('lambdac', alpha=.001, beta=.001)
    # One d_i and c_i per row of the induction data (k = len(h1)).
    di = pm.Normal('di', mu=mud, tau=lambdad, shape=k)
    ci = pm.Normal('ci', mu=muc, tau=lambdac, shape=k)
    thetah = pm.Deterministic('Hit Rate', Phi(di/2 - ci))
    thetaf = pm.Deterministic('False Alarm Rate', Phi(-di/2 - ci))
    hi = pm.Binomial('hi', p=thetah, n=s1, observed=h1)
    fi = pm.Binomial('fi', p=thetaf, n=n1, observed=f1)
    # The number of draws must be an integer; 3e3 is a float literal.
    trace_i = pm.sample(3000, njobs=2)
with pm.Model() as model2d:
    # Vague hyperpriors: population means with precision .001 and
    # Gamma(.001, .001) priors on the population precisions.
    mud = pm.Normal('mud', mu=0, tau=.001)
    muc = pm.Normal('muc', mu=0, tau=.001)
    lambdad = pm.Gamma('lambdad', alpha=.001, beta=.001)
    lambdac = pm.Gamma('lambdac', alpha=.001, beta=.001)
    # One d_i and c_i per row of the deduction data. Size from h2 itself:
    # the module-level k is computed from the induction data (len(h1)).
    di = pm.Normal('di', mu=mud, tau=lambdad, shape=len(h2))
    ci = pm.Normal('ci', mu=muc, tau=lambdac, shape=len(h2))
    thetah = pm.Deterministic('Hit Rate', Phi(di/2 - ci))
    thetaf = pm.Deterministic('False Alarm Rate', Phi(-di/2 - ci))
    hi = pm.Binomial('hi', p=thetah, n=s2, observed=h2)
    fi = pm.Binomial('fi', p=thetaf, n=n2, observed=f2)
    # The number of draws must be an integer; 3e3 is a float literal.
    trace_d = pm.sample(3000, njobs=2)
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 99%|█████████▉| 3476/3500.0 [00:10<00:00, 377.28it/s]/home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 27 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 100%|██████████| 3500/3500.0 [00:10<00:00, 341.94it/s] /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 18 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 62%|██████▏ | 2154/3500.0 [00:09<00:04, 297.69it/s]/home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 76 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 99%|█████████▉| 3476/3500.0 [00:14<00:00, 279.62it/s]/home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:451: UserWarning: The acceptance probability in chain 0 does not match the target. It is 0.932983672035, but should be close to 0.8. Try to increase the number of tuning steps. % (self._chain_id, mean_accept, target_accept)) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 87 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 100%|██████████| 3500/3500.0 [00:14<00:00, 248.64it/s]
def scatterplot_2trace(trace1, trace2, burnin=500):
    """Overlay the joint posterior of (mu_d, mu_c) from two traces.

    Draws a scatter plot of ``'mud'`` against ``'muc'`` with marginal
    histograms of mu_d on top and mu_c on the right. trace1 is drawn in red,
    trace2 in blue.

    Parameters
    ----------
    trace1, trace2 : trace objects indexable by ``'mud'`` and ``'muc'``
        The two posterior traces to compare.
    burnin : int, optional
        Number of leading samples dropped from each trace before plotting
        (default 500; use 0 to plot every sample).
    """
    from matplotlib.ticker import NullFormatter
    nullfmt = NullFormatter()  # hides tick labels on the marginal axes

    # Axes rectangles in figure coordinates.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02
    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    plt.figure(figsize=(8, 8))
    # Create each axes exactly once. Re-creating them per trace stacks a new
    # (opaque) axes over the previous one on newer matplotlib, hiding trace1.
    axScatter = plt.axes(rect_scatter)
    axHistx = plt.axes(rect_histx)
    axHisty = plt.axes(rect_histy)
    axScatter.set_xlim((-1, 6))
    axScatter.set_ylim((-3, 3))
    axHistx.xaxis.set_major_formatter(nullfmt)
    axHisty.yaxis.set_major_formatter(nullfmt)
    axScatter.set_xlabel(r'$\mu_d$', fontsize=18)
    axScatter.set_ylabel(r'$\mu_c$', fontsize=18)

    bins1 = np.linspace(-1, 6, 50)
    bins2 = np.linspace(-3, 3, 50)
    colors = ((1., 0., 0.), (0., 0., 1.))
    for trace, color in zip((trace1, trace2), colors):
        x = trace['mud'][burnin:]
        y = trace['muc'][burnin:]
        axScatter.scatter(x, y, color=color, alpha=.05)
        # `density` replaces the removed `normed` keyword of Axes.hist.
        axHistx.hist(x, bins=bins1, color=color, alpha=.5, density=True)
        axHisty.hist(y, bins=bins2, color=color, alpha=.5, density=True,
                     orientation='horizontal')

    axHistx.set_xlim(axScatter.get_xlim())
    axHisty.set_ylim(axScatter.get_ylim())
    plt.show()
scatterplot_2trace(trace_i, trace_d)

# Trace of the population standard deviation of the bias parameters c_i.
# lambdac is a precision (it is the ``tau`` of the ci prior), so the standard
# deviation is sigma_c = 1 / sqrt(lambdac).
plt.figure(figsize=(15, 6))
gs = gridspec.GridSpec(2, 1)
for row, (tmptrace, title) in enumerate([(trace_i['lambdac'], 'Induction Condition'),
                                         (trace_d['lambdac'], 'Deduction Condition')]):
    plt.subplot(gs[row])
    plt.plot(1 / np.sqrt(tmptrace))
    plt.ylim(0, 2)
    plt.title(title)
    plt.ylabel(r'$\sigma_c$', fontsize=18)
plt.show()
Using parameter expansion to escape the "zero variance trap" in MCMC sampling. (The second argument of each Gaussian below is a precision, not a variance.)

$$ \mu_{d},\mu_{c} \sim \text{Gaussian}(0,.001)$$
$$ \lambda_{d},\lambda_{c} \sim \text{Gamma}(.001,.001)$$
$$ \xi_{d},\xi_{c} \sim \text{Beta}(1,1)$$
$$ \delta_{d_{i}} \sim \text{Gaussian}(0,\lambda_{d})$$
$$ \delta_{c_{i}} \sim \text{Gaussian}(0,\lambda_{c})$$
$$ \sigma_{d} = \lvert \xi_{d}\rvert\,/\,\sqrt{\lambda_{d}}$$
$$ \sigma_{c} = \lvert \xi_{c}\rvert\,/\,\sqrt{\lambda_{c}}$$
$$ d_{i} = \mu_{d} + \xi_{d}\delta_{d_{i}}$$
$$ c_{i} = \mu_{c} + \xi_{c}\delta_{c_{i}}$$
$$ \theta_{i}^h = \Phi(\frac{1}{2}d_{i}-c_{i})$$
$$ \theta_{i}^f = \Phi(-\frac{1}{2}d_{i}-c_{i})$$
$$ h_{i} \sim \text{Binomial}(\theta_{i}^h,s_{i})$$
$$ f_{i} \sim \text{Binomial}(\theta_{i}^f,n_{i})$$

with pm.Model() as model3i:
# Body of `with pm.Model() as model3i:` (parameter-expanded induction model).
# Vague hyperpriors on the population means and precisions.
mud = pm.Normal('mud', mu=0, tau=.001)
muc = pm.Normal('muc', mu=0, tau=.001)
lambdad = pm.Gamma('lambdad', alpha=.001, beta=.001)
lambdac = pm.Gamma('lambdac', alpha=.001, beta=.001)
# Expansion multipliers xi ~ Uniform(0, 1), i.e. Beta(1, 1).
xid = pm.Uniform('xid', lower=0, upper=1)
xic = pm.Uniform('xic', lower=0, upper=1)
deltadi = pm.Normal('deltadi', mu=0, tau=lambdad, shape=k)
deltaci = pm.Normal('deltaci', mu=0, tau=lambdac, shape=k)
# Implied population standard deviations: sigma = xi / sqrt(lambda).
sigmad = pm.Deterministic('sigmad', xid/tt.sqrt(lambdad))
sigmac = pm.Deterministic('sigmac', xic/tt.sqrt(lambdac))
di = pm.Deterministic('di', mud + xid*deltadi)
ci = pm.Deterministic('ci', muc + xic*deltaci)
thetah = pm.Deterministic('Hit Rate', Phi(di/2 - ci))
thetaf = pm.Deterministic('False Alarm Rate', Phi(-di/2 - ci))
hi = pm.Binomial('hi', p=thetah, n=s1, observed=h1)
fi = pm.Binomial('fi', p=thetaf, n=n1, observed=f1)
# The number of draws must be an integer; 3e3 is a float literal.
trace_i2 = pm.sample(3000, njobs=2)
with pm.Model() as model3d:
    # Vague hyperpriors on the population means and precisions.
    mud = pm.Normal('mud', mu=0, tau=.001)
    muc = pm.Normal('muc', mu=0, tau=.001)
    lambdad = pm.Gamma('lambdad', alpha=.001, beta=.001)
    lambdac = pm.Gamma('lambdac', alpha=.001, beta=.001)
    # Expansion multipliers xi ~ Uniform(0, 1), i.e. Beta(1, 1).
    xid = pm.Uniform('xid', lower=0, upper=1)
    xic = pm.Uniform('xic', lower=0, upper=1)
    # Size from the deduction data itself: the module-level k is computed
    # from the induction data (len(h1)).
    deltadi = pm.Normal('deltadi', mu=0, tau=lambdad, shape=len(h2))
    deltaci = pm.Normal('deltaci', mu=0, tau=lambdac, shape=len(h2))
    # Implied population standard deviations: sigma = xi / sqrt(lambda).
    sigmad = pm.Deterministic('sigmad', xid/tt.sqrt(lambdad))
    sigmac = pm.Deterministic('sigmac', xic/tt.sqrt(lambdac))
    di = pm.Deterministic('di', mud + xid*deltadi)
    ci = pm.Deterministic('ci', muc + xic*deltaci)
    thetah = pm.Deterministic('Hit Rate', Phi(di/2 - ci))
    thetaf = pm.Deterministic('False Alarm Rate', Phi(-di/2 - ci))
    hi = pm.Binomial('hi', p=thetah, n=s2, observed=h2)
    fi = pm.Binomial('fi', p=thetaf, n=n2, observed=f2)
    # The number of draws must be an integer; 3e3 is a float literal.
    trace_d2 = pm.sample(3000, njobs=2)
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 100%|██████████| 3500/3500.0 [00:42<00:00, 97.40it/s] /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 148 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:459: UserWarning: Chain 1 reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize. 'reparameterize.' % self._chain_id) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:451: UserWarning: The acceptance probability in chain 1 does not match the target. It is 0.713785426377, but should be close to 0.8. Try to increase the number of tuning steps. % (self._chain_id, mean_accept, target_accept)) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 274 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... 100%|█████████▉| 3498/3500.0 [00:22<00:00, 59.69it/s] /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:451: UserWarning: The acceptance probability in chain 0 does not match the target. It is 0.646569049019, but should be close to 0.8. Try to increase the number of tuning steps. % (self._chain_id, mean_accept, target_accept)) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 0 contains 286 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging)) 100%|██████████| 3500/3500.0 [00:22<00:00, 154.29it/s] /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:451: UserWarning: The acceptance probability in chain 1 does not match the target. 
It is 0.714662677258, but should be close to 0.8. Try to increase the number of tuning steps. % (self._chain_id, mean_accept, target_accept)) /home/laoj/Documents/Github/pymc3/pymc3/step_methods/hmc/nuts.py:467: UserWarning: Chain 1 contains 317 diverging samples after tuning. If increasing `target_accept` does not help try to reparameterize. % (self._chain_id, n_diverging))
scatterplot_2trace(trace_i2, trace_d2)

# Trace plots of the implied population SD of the bias, one panel per condition.
plt.figure(figsize=(15, 6))
gs = gridspec.GridSpec(2, 1)
panels = (('Induction Condition', trace_i2), ('Deduction Condition', trace_d2))
for row, (label, tr) in enumerate(panels):
    tmptrace = tr['sigmac'][:]
    plt.subplot(gs[row])
    plt.plot(tmptrace)
    plt.title(label)
    plt.ylabel(r'$\sigma_c$', fontsize=18)
plt.show()
def plot_samplerstat(burnin, trace):
    """Print NUTS sampler diagnostics for *trace* and plot its energy distributions.

    Parameters
    ----------
    burnin : int
        Samples dropped from the start of each chain when computing the mean
        acceptance rate. Divergences and energies use the whole trace.
    trace : pymc3 MultiTrace
        A NUTS trace with the 'mean_tree_accept', 'diverging' and 'energy'
        sampler statistics.
    """
    # Sampler statistics
    accept = trace.get_sampler_stats('mean_tree_accept', burn=burnin)
    print('The accept rate is: %.5f' % (accept.mean()))
    diverge = trace.get_sampler_stats('diverging')
    print('Diverge of the trace')
    print(diverge.nonzero())
    # Collect energies chain by chain so np.diff never takes a difference
    # across the seam between two concatenated chains.
    energies = [trace.get_sampler_stats('energy', chains=[chain])
                for chain in trace.chains]
    energy = np.concatenate(energies)
    energy_diff = np.concatenate([np.diff(e) for e in energies])
    sns.distplot(energy - energy.mean(), label='energy')
    sns.distplot(energy_diff, label='energy diff')
    plt.legend()
    plt.show()
plot_samplerstat(burnin=2000, trace=trace_i)
The accept rate is: 0.84085 Diverge of the trace (array([ 52, 53, 54, 66, 127, 128, 162, 432, 527, 629, 720, 722, 723, 724, 734, 736, 738, 818, 1226, 1322, 1453, 1556, 1711, 2314, 2665, 2807, 2904, 3530, 3531, 3532, 3664, 3665, 3696, 3721, 3745, 4465, 4664, 4666, 4667, 4671, 4711, 4867, 4871, 5143, 5146]),)
plot_samplerstat(burnin=2000, trace=trace_d)
The accept rate is: 0.85705 Diverge of the trace (array([ 33, 309, 461, 470, 484, 503, 580, 581, 582, 583, 585, 669, 725, 928, 929, 930, 931, 932, 938, 987, 1100, 1131, 1222, 1354, 1361, 1496, 1604, 1605, 1666, 1672, 1720, 1784, 1828, 1829, 1830, 1833, 2046, 2053, 2172, 2178, 2189, 2237, 2238, 2239, 2243, 2251, 2264, 2310, 2313, 2373, 2381, 2394, 2450, 2559, 2562, 2577, 2578, 2579, 2580, 2581, 2582, 2589, 2592, 2593, 2594, 2596, 2613, 2645, 2647, 2773, 2774, 2777, 2783, 2785, 2793, 2795, 2801, 2803, 2804, 2809, 2810, 2874, 2883, 2884, 2952, 2953, 2954, 3016, 3021, 3049, 3151, 3267, 3371, 3434, 3435, 3582, 3591, 3594, 3643, 3737, 3757, 3842, 4080, 4085, 4117, 4196, 4197, 4198, 4287, 4344, 4345, 4392, 4488, 4491, 4554, 4557, 4594, 4655, 4658, 4678, 4919, 4935, 4947, 4949, 4981, 4982, 4984, 4985, 4986, 4987, 4989, 5021, 5022, 5023, 5044, 5074, 5092, 5094, 5121, 5122, 5262, 5269, 5277, 5278, 5279, 5280, 5288, 5312, 5313, 5314, 5378, 5392, 5401, 5406, 5437, 5487, 5618, 5673, 5705, 5845, 5890, 5911, 5977]),)
plot_samplerstat(burnin=2000, trace=trace_i2)
The accept rate is: 0.78346 Diverge of the trace (array([ 94, 104, 106, 111, 113, 114, 119, 124, 130, 134, 135, 136, 137, 142, 143, 150, 153, 154, 161, 162, 170, 171, 172, 174, 177, 185, 196, 197, 233, 269, 282, 293, 294, 300, 302, 327, 331, 335, 347, 352, 354, 355, 363, 365, 377, 378, 380, 382, 384, 385, 420, 424, 437, 474, 475, 476, 477, 478, 479, 509, 568, 581, 622, 661, 772, 774, 783, 804, 806, 810, 811, 823, 827, 828, 830, 867, 869, 907, 908, 909, 912, 913, 916, 934, 943, 951, 962, 975, 976, 1029, 1155, 1257, 1258, 1274, 1293, 1295, 1477, 1478, 1489, 1490, 1507, 1533, 1544, 1551, 1737, 1738, 1748, 1756, 1759, 1769, 1770, 1785, 1786, 1789, 1797, 1802, 1803, 1808, 1815, 1824, 1829, 1844, 1846, 1848, 1855, 2235, 2236, 2445, 2636, 2637, 2647, 2650, 2658, 2665, 2667, 2668, 2669, 2671, 2672, 2678, 2688, 2693, 2697, 2703, 2841, 2880, 2924, 2946, 3033, 3036, 3067, 3085, 3097, 3098, 3100, 3101, 3113, 3115, 3121, 3146, 3558, 3581, 3596, 3659, 3660, 3663, 3665, 3674, 3679, 3688, 3693, 3694, 3695, 3696, 3698, 3707, 3709, 3710, 3711, 3714, 3716, 3718, 3719, 3723, 3724, 3728, 3731, 3733, 3738, 3741, 3744, 3745, 3746, 3747, 3748, 3750, 3751, 3752, 3753, 3761, 3763, 3764, 3765, 3766, 3777, 3781, 3782, 3784, 3789, 3791, 3792, 3793, 3795, 3798, 3799, 3800, 3801, 3803, 3805, 3809, 3811, 3815, 3817, 3818, 3819, 3821, 3822, 3824, 3825, 3826, 3828, 3829, 3830, 3832, 3833, 3834, 3835, 3838, 3840, 3841, 3842, 3843, 3844, 3846, 3847, 3848, 3849, 3850, 3861, 3867, 3870, 3875, 3877, 3878, 3880, 3885, 3886, 3887, 3889, 3891, 3892, 3893, 3894, 3898, 3906, 3912, 3913, 3914, 3915, 3916, 3919, 3920, 3927, 3928, 3929, 3931, 3932, 3933, 3935, 3936, 3938, 3939, 3940, 3942, 3943, 3945, 3947, 3953, 3956, 3964, 3965, 3968, 3970, 3972, 3978, 3982, 3983, 3984, 3988, 3989, 3993, 3994, 3997, 4000, 4002, 4003, 4006, 4010, 4011, 4013, 4014, 4015, 4018, 4023, 4024, 4026, 4027, 4028, 4029, 4030, 4032, 4036, 4038, 4040, 4041, 4042, 4046, 4048, 4049, 4051, 4052, 4054, 4055, 4059, 4060, 4063, 4066, 4072, 
4074, 4078, 4081, 4090, 4091, 4095, 4105, 4107, 4114, 4127, 4148, 4209, 4210, 4211, 4235, 4470, 4471, 4472, 4473, 4474, 4475, 4476, 4477, 4732, 4735, 4749, 4763, 4872, 4877, 4879, 4912, 4913, 4914, 5185, 5254, 5489, 5523, 5528, 5535, 5539, 5561, 5621, 5673, 5692, 5728, 5747, 5771, 5783, 5796, 5797, 5801, 5802, 5804, 5805, 5806, 5807, 5809, 5814, 5823, 5837, 5839, 5842, 5843, 5846, 5850, 5852, 5862, 5864, 5870, 5881, 5892, 5896, 5915, 5925, 5926, 5927, 5946, 5949, 5959, 5960, 5976, 5982, 5985, 5988]),)
plot_samplerstat(burnin=2000, trace=trace_d2)
The accept rate is: 0.62782 Diverge of the trace (array([ 3, 8, 9, 10, 11, 13, 14, 15, 16, 18, 34, 69, 74, 77, 80, 85, 87, 88, 89, 92, 120, 122, 127, 138, 139, 140, 141, 142, 143, 145, 164, 285, 299, 327, 329, 335, 347, 350, 360, 361, 366, 369, 370, 371, 375, 378, 392, 399, 415, 416, 417, 418, 435, 450, 457, 465, 497, 498, 578, 584, 591, 623, 627, 652, 660, 681, 682, 689, 690, 691, 692, 693, 711, 739, 754, 770, 825, 851, 888, 891, 892, 894, 913, 914, 918, 927, 928, 934, 935, 942, 948, 951, 952, 959, 987, 989, 990, 1007, 1018, 1021, 1023, 1043, 1045, 1051, 1052, 1053, 1059, 1065, 1071, 1074, 1088, 1126, 1137, 1143, 1144, 1186, 1187, 1213, 1224, 1286, 1289, 1294, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1304, 1366, 1367, 1368, 1373, 1382, 1394, 1395, 1412, 1490, 1502, 1503, 1509, 1514, 1521, 1526, 1564, 1565, 1573, 1575, 1598, 1603, 1604, 1608, 1609, 1771, 1780, 1785, 1791, 1792, 1810, 1824, 1828, 1833, 1834, 1838, 1839, 1840, 1853, 1854, 1864, 1869, 1870, 1915, 1929, 1998, 2004, 2029, 2030, 2057, 2062, 2063, 2064, 2087, 2090, 2094, 2095, 2096, 2102, 2105, 2106, 2108, 2113, 2119, 2120, 2122, 2123, 2124, 2126, 2134, 2141, 2143, 2151, 2155, 2156, 2166, 2167, 2185, 2186, 2187, 2189, 2193, 2201, 2203, 2204, 2205, 2206, 2207, 2209, 2210, 2211, 2212, 2213, 2214, 2216, 2248, 2253, 2264, 2290, 2293, 2352, 2376, 2391, 2454, 2477, 2479, 2483, 2522, 2673, 2687, 2693, 2712, 2764, 2806, 2824, 2826, 2831, 2832, 2838, 2839, 2840, 2854, 2868, 2870, 2881, 2888, 2895, 2905, 2906, 2911, 2913, 2915, 2916, 2920, 2924, 2926, 2948, 2951, 2955, 2960, 2964, 2965, 2966, 2967, 2969, 2980, 2982, 2983, 2986, 2988, 2989, 2993, 2995, 2996, 2997, 2998, 2999, 3021, 3023, 3027, 3034, 3035, 3037, 3038, 3040, 3041, 3053, 3054, 3056, 3057, 3073, 3074, 3086, 3087, 3088, 3090, 3094, 3095, 3097, 3107, 3113, 3115, 3117, 3120, 3121, 3132, 3133, 3146, 3182, 3183, 3186, 3271, 3276, 3278, 3296, 3307, 3309, 3310, 3312, 3316, 3317, 3319, 3321, 3329, 3330, 3331, 3333, 3334, 3336, 3352, 3353, 3371, 3376, 3377, 
3380, 3384, 3385, 3386, 3413, 3414, 3427, 3452, 3453, 3460, 3503, 3504, 3510, 3544, 3550, 3562, 3563, 3619, 3620, 3624, 3625, 3630, 3632, 3640, 3764, 3778, 3783, 3785, 3819, 3833, 3834, 3866, 3936, 3957, 3959, 3970, 3976, 4017, 4025, 4027, 4029, 4034, 4036, 4038, 4040, 4042, 4043, 4044, 4045, 4046, 4047, 4051, 4052, 4053, 4054, 4055, 4059, 4060, 4063, 4105, 4114, 4117, 4122, 4123, 4135, 4138, 4139, 4147, 4149, 4150, 4163, 4164, 4165, 4177, 4222, 4290, 4294, 4323, 4370, 4390, 4405, 4466, 4474, 4553, 4554, 4563, 4567, 4595, 4672, 4676, 4721, 4722, 4724, 4732, 4733, 4734, 4735, 4737, 4738, 4742, 4743, 4745, 4752, 4753, 4754, 4755, 4761, 4764, 4765, 4775, 4776, 4803, 4806, 4884, 4885, 4886, 4887, 4888, 4889, 4890, 4894, 4898, 4915, 4916, 4919, 4923, 4928, 4936, 4937, 4948, 4955, 4958, 4959, 4970, 4971, 4973, 4976, 4983, 4991, 4992, 4998, 5074, 5084, 5086, 5107, 5122, 5126, 5149, 5285, 5327, 5334, 5335, 5336, 5337, 5338, 5339, 5349, 5354, 5356, 5359, 5360, 5363, 5368, 5371, 5375, 5377, 5378, 5381, 5382, 5383, 5384, 5387, 5388, 5389, 5390, 5391, 5392, 5393, 5394, 5396, 5397, 5398, 5399, 5400, 5402, 5403, 5405, 5408, 5409, 5410, 5411, 5412, 5413, 5414, 5416, 5417, 5418, 5419, 5420, 5421, 5422, 5423, 5424, 5425, 5426, 5428, 5429, 5430, 5431, 5433, 5434, 5435, 5438, 5454, 5457, 5465, 5468, 5469, 5470, 5472, 5473, 5474, 5477, 5478, 5480, 5498, 5501, 5502, 5514, 5515, 5611, 5613, 5626, 5629, 5664, 5674, 5715, 5790, 5791, 5796, 5797, 5864, 5866, 5867, 5905, 5906, 5907, 5908, 5925, 5958, 5962, 5969, 5974, 5979, 5981, 5985, 5986, 5990, 5995, 5996]),)
As shown above, there are many divergences in the traces, and the marginal energy distribution differs markedly from the distribution of energy transitions (energy diff). This strongly indicates that the estimates are biased, and a better reparameterization is needed.
Moreover, the parameter-expansion reparameterization, which works better in BUGS/JAGS with a Gibbs sampler, actually performs worse with NUTS. Again, this shows that many of the tricks and intuitions we developed using BUGS/JAGS might not translate to PyMC3 and Stan.