import pymc3 as pm
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%qtconsole --colors=linux
plt.style.use('ggplot')
from matplotlib import gridspec
from theano import tensor as tt
from scipy import stats
The take-the-best (TTB) model of decision-making (Gigerenzer & Goldstein, 1996) is a simple but influential account of how people choose between two stimuli on some criterion, and a good example of the general class of heuristic decision-making models (e.g., Gigerenzer & Todd, 1999; Gigerenzer & Gaissmaier, 2011; Payne, Bettman, & Johnson, 1990).
$$ t_q = \text{TTB}_{s}(\mathbf a_q,\mathbf b_q)$$
$$ \gamma \sim \text{Uniform}(0.5,1)$$
$$ y_{iq} \sim \begin{cases} \text{Bernoulli}(\gamma) & \text{if $t_q = a$} \\ \text{Bernoulli}(1- \gamma) & \text{if $t_q = b$} \\ \text{Bernoulli}(0.5) & \text{otherwise} \end{cases} $$
import scipy.io as sio
# Load the experiment data from the MATLAB file.
matdata = sio.loadmat('data/StopSearchData.mat')
y = np.squeeze(matdata['y'])  # observed decisions; rows index subjects (see ns)
m = np.squeeze(np.float32(matdata['m']))  # cue matrix: stimuli x cues
# p holds 1-based stimulus indices that are used directly for array indexing
# below. MATLAB stores numeric arrays as double by default, and float arrays
# cannot be used as NumPy indices, so cast explicitly (no-op if already int).
p = np.squeeze(matdata['p']).astype(int)
v = np.squeeze(np.float32(matdata['v']))  # cue validities
x = np.squeeze(np.float32(matdata['x']))  # cue weights used by WADD
# Constants
n, nc = np.shape(m)  # number of stimuli and cues
nq, _ = np.shape(p)  # number of questions
ns, _ = np.shape(y)  # number of subjects

# JAGS: s[1:nc] <- rank(v[1:nc]).  np.argsort returns the sorting permutation,
# which equals the rank only by coincidence (e.g. when v is already sorted);
# applying argsort twice gives the 0-based rank of every cue's validity, so
# the most valid cue gets the largest rank and therefore the largest weight.
s = np.argsort(np.argsort(v))

# TTB Model For Each Question
# Cue j contributes (first - second stimulus) * 2**s[j]; with power-of-two
# weights the highest-ranked discriminating cue dominates the sum, which
# mimics the TTB decision.  p indices are 1-based.
cue_diff = m[p[:, 0] - 1, :] - m[p[:, 1] - 1, :]  # (nq, nc)
evidence = np.sum(cue_diff * np.power(2, s), axis=1)  # (nq,)
# Find if Cue Favors First (2), Second (0), or Neither (1) Stimulus; these
# values index gammat = [1-gamma, .5, gamma] in the models below.
t = np.sign(evidence).astype(int) + 1
tmat = np.tile(t[np.newaxis, :], (ns, 1))
# Model 1: TTB with the known (validity-based) search order for everyone.
with pm.Model() as model1:
    # Probability of following the TTB decision when a cue discriminates.
    gamma = pm.Uniform('gamma', lower=.5, upper=1)
    # Indexed by tmat in {0, 1, 2}: second favoured -> 1-gamma,
    # no discriminating cue -> .5 (guess), first favoured -> gamma.
    gammat = tt.stack([1-gamma, .5, gamma])
    yiq = pm.Bernoulli('yiq', p=gammat[tmat], observed=y)
    # The number of draws is a count: pass an int rather than the float 3e3.
    trace1 = pm.sample(3000, njobs=2, tune=1000)
pm.traceplot(trace1);
Auto-assigning NUTS sampler... Initializing NUTS using ADVI... Average Loss = 338.3: 3%|▎ | 6745/200000 [00:00<00:20, 9472.55it/s] Convergence archived at 7400 Interrupted at 7,400 [3%]: Average Loss = 340.61 100%|██████████| 4000/4000.0 [00:02<00:00, 1418.51it/s]
# Posterior predictive check: 100 replicated data sets drawn from model 1.
ppc = pm.sample_ppc(trace1, samples=100, model=model1)
yiqpred = np.asarray(ppc['yiq'])

fig = plt.figure(figsize=(16, 8))
# Coordinates of every (subject, question) cell in row-major order:
# x1 = subject number, y1 = question number (both 1-based).
grid_q, grid_s = np.meshgrid(np.arange(nq) + 1, np.arange(ns) + 1)
x1 = grid_s.flatten()
y1 = grid_q.flatten()
observed_one = y.flatten() == 1

# Circle area scales with the mean replicated response in each cell.
plt.scatter(y1, x1, s=yiqpred.mean(axis=0) * 200, c='w')
# Red crosses mark observed responses equal to 1.
plt.scatter(y1[observed_one], x1[observed_one], marker='x', c='r')
# Dashed divider between questions 24 and 25.
plt.plot(np.full(100, 24.5), np.linspace(0, 21, 100), '--', lw=1.5, c='k')
plt.axis([0, 31, 0, 21])
plt.show()
100%|██████████| 100/100 [00:00<00:00, 129.04it/s]
A common comparison (e.g., Bergert & Nosofsky, 2007; Lee & Cummins, 2004) is between TTB and a model often called the Weighted ADDitive (WADD) model, which sums the evidence for both decision alternatives over all available cues, and chooses the one with the greatest evidence.
$$ \phi \sim \text{Uniform}(0,1)$$
$$ z_i \sim \text{Bernoulli}(\phi)$$
$$ \gamma \sim \text{Uniform}(0.5,1)$$
$$ t_{iq} = \begin{cases} \text{TTB}\,(\mathbf a_q,\mathbf b_q) & \text{if $z_i = 1$} \\ \text{WADD}\,(\mathbf a_q,\mathbf b_q) & \text{if $z_i = 0$} \\ \end{cases} $$
$$ y_{iq} \sim \begin{cases} \text{Bernoulli}(\gamma) & \text{if $t_{iq} = a$} \\ \text{Bernoulli}(1- \gamma) & \text{if $t_{iq} = b$} \\ \text{Bernoulli}(0.5) & \text{otherwise} \end{cases} $$
# Question cue contributions template
# Question cue contributions: qcc[q, j] is the difference on cue j between
# the first and second stimulus of question q (p indices are 1-based).  Kept
# in float64, as in the original zeros-initialised array, so downstream sums
# are computed at the same precision.
qcc = (m[p[:, 0] - 1, :] - m[p[:, 1] - 1, :]).astype(np.float64)
qccmat = np.tile(qcc[np.newaxis, :, :], (ns, 1, 1))  # (ns, nq, nc)

# TTB Model For Each Question
# JAGS: s[1:nc] <- rank(v[1:nc]).  A single argsort is the sorting
# permutation, not the rank; argsort twice gives the 0-based rank.
s = np.argsort(np.argsort(v))
ttmp = np.sum(qcc * np.power(2, s), axis=1)  # weighted TTB evidence, (nq,)
# 2: first stimulus favoured, 0: second favoured, 1: neither.
tmat = np.tile(np.sign(ttmp).astype(int) + 1, (ns, 1))
t = tmat[0]

# WADD Model For Each Question: evidence summed over all cues with weights x.
wtmp = np.sum(qcc * x, axis=1)  # (nq,)
wmat = np.tile(np.sign(wtmp).astype(int) + 1, (ns, 1))
w = wmat[0]
print(t)
print(w)
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0]
# Model 2: latent mixture of TTB and WADD decision makers.
with pm.Model() as model2:
    # Base rate of TTB users; Beta(1, 1) is Uniform(0, 1).
    phi = pm.Beta('phi', alpha=1, beta=1, testval=.01)
    # zi[i] = 1: subject i uses TTB (tmat); zi[i] = 0: WADD (wmat).
    zi = pm.Bernoulli('zi', p=phi, shape=ns,
                      testval=np.asarray([1,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]))
    # Repeat each subject's indicator over its nq questions -> (ns, nq).
    zi_ = tt.reshape(tt.repeat(zi, nq), (ns, nq))
    gamma = pm.Uniform('gamma', lower=.5, upper=1)
    # Indexed by the predicted decision in {0, 1, 2}.
    gammat = tt.stack([1-gamma, .5, gamma])
    # Per subject, take the TTB or the WADD prediction.
    t2 = tt.switch(tt.eq(zi_, 1), tmat, wmat)
    yiq = pm.Bernoulli('yiq', p=gammat[t2], observed=y)
    # The number of draws is a count: pass an int rather than the float 3e3.
    trace2 = pm.sample(3000, njobs=2, tune=1000)
pm.traceplot(trace2);
Assigned NUTS to phi_logodds__ Assigned BinaryGibbsMetropolis to zi Assigned NUTS to gamma_interval__ 100%|██████████| 4000/4000.0 [00:18<00:00, 221.59it/s]
fig = plt.figure(figsize=(16, 4))
# Slice the MultiTrace itself so the first 1000 draws are removed from every
# chain.  trace2['zi'][1000:] slices the chain-concatenated array and would
# drop draws from the first chain only.
zitrc = trace2[1000:]['zi']
# Bar height = posterior probability that z_i = 0, i.e. the subject is WADD.
plt.bar(np.arange(ns)+1, 1-np.mean(zitrc, axis=0))
plt.yticks([0, 1], ('TTB', 'WADD'))
plt.xlabel('Subject')
plt.ylabel('Group')
plt.axis([0, ns + 1, 0, 1])
plt.show()
$$ t_{iq} = \text{TTB}_{s_i}(\mathbf a_q,\mathbf b_q)$$
$$ \gamma \sim \text{Uniform}(0.5,1)$$
$$ y_{iq} \sim \begin{cases} \text{Bernoulli}(\gamma) & \text{if $t_{iq} = a$} \\ \text{Bernoulli}(1- \gamma) & \text{if $t_{iq} = b$} \\ \text{Bernoulli}(0.5) & \text{otherwise} \end{cases} $$
# Model 3: TTB with a latent, subject-specific cue search order.
with pm.Model() as model3:
    # Probability of following the TTB decision when a cue discriminates.
    gamma = pm.Uniform('gamma', lower=.5, upper=1)
    # Indexed by tmat in {0, 1, 2}: 1-gamma, .5 (guess), gamma.
    gammat = tt.stack([1-gamma, .5, gamma])
    # One latent positive score per (subject, cue); only their ordering
    # enters the likelihood (through the argsort below).
    v1 = pm.HalfNormal('v1', sd=1, shape=ns*nc)
    # Per-subject argsort over cues, shape (ns, 1, nc), stored in the trace.
    # NOTE(review): argsort is the sorting permutation, not rank(v1).  Because
    # the v1 entries are i.i.d. a priori, both give a uniform prior over
    # permutations, but s1[i, 0, j] should be read as the exponent (priority)
    # assigned to cue j, not as a list of cues in search order -- confirm
    # before comparing with the book's reported orders.
    s1 = pm.Deterministic('s1', tt.argsort(v1.reshape((ns, 1, nc)), axis=2))
    # Broadcast each subject's exponents to every question: (ns, nq, nc).
    smat2 = tt.tile(s1, (1, nq, 1)) # s[1:nc] <- rank(v[1:nc])
    # TTB Model For Each Question
    # Weighted cue evidence per (subject, question), shape (ns, nq).
    ttmp = tt.sum(qccmat*tt.power(2, smat2), axis=2)
    # Sign of the evidence shifted to {0, 1, 2} to index gammat.
    tmat = -1*(-ttmp > 0)+(ttmp > 0)+1
    yiq = pm.Bernoulli('yiq', p=gammat[tmat], observed=y)
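It is important to note that the sorting operation s[1:nc] <- rank(v[1:nc])
breaks the smoothness of the posterior geometry: the likelihood depends on the latent scores only through their ordering, so it is piecewise constant and its gradient is zero almost everywhere. Gradient-based methods such as NUTS and ADVI therefore receive no useful signal and are likely to return wrong estimates, with the sampler getting stuck in a local mode.
For this reason, we use Metropolis to sample from this model.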
with model3:
    # The likelihood depends on v1 only through an argsort, so it is piecewise
    # constant in v1 and gradient-based samplers (NUTS/ADVI) get no signal;
    # use gradient-free Metropolis.  The draw count is passed as an int.
    trace3 = pm.sample(100000, step=pm.Metropolis(), njobs=2)
pm.traceplot(trace3, varnames=['gamma', 'v1']);
100%|██████████| 100500/100500.0 [01:57<00:00, 858.69it/s]
burnin = 50000
num_display = 10
# Slicing the MultiTrace drops the first `burnin` draws from every chain.
# s1 is stored with shape (draws, ns, 1, nc); squeeze the singleton axis.
s1trace = np.squeeze(trace3[burnin:]['s1'])
for subj_id in [12, 13]:
    subj_s = s1trace[:, subj_id-1, :]  # (draws, nc) for this subject
    # Distinct sampled orders and their frequencies in a single pass.  This
    # replaces np.vstack over a set (rejected by newer NumPy) plus a full
    # rescan of all draws for every distinct order.
    unique_ord, counts = np.unique(subj_s, axis=0, return_counts=True)
    mass_ = counts / subj_s.shape[0]
    print('Subject %s' %(subj_id))
    print('There are %s search orders sampled in the posterior.'%(unique_ord.shape[0]))
    # NOTE(review): each row is the argsort-derived exponent per cue (larger
    # = higher priority in the TTB weights); the cue search order itself would
    # be np.argsort(-s_) + 1 -- confirm before comparing with the book.
    for i in np.argsort(mass_)[::-1][:num_display]:
        s_ = unique_ord[i]
        print('Order=(' + str(s_+1) + '), Estimated Mass=' + str(mass_[i]))
Subject 12 There are 342 search orders sampled in the posterior. Order=([1 3 5 4 9 2 6 7 8]), Estimated Mass=0.06106 Order=([1 5 3 4 9 2 7 6 8]), Estimated Mass=0.04989 Order=([1 5 3 4 9 2 6 7 8]), Estimated Mass=0.03724 Order=([1 3 5 4 9 2 7 6 8]), Estimated Mass=0.0354 Order=([6 2 3 7 9 1 5 4 8]), Estimated Mass=0.02863 Order=([6 3 2 7 9 1 5 4 8]), Estimated Mass=0.02839 Order=([1 3 4 5 9 2 6 7 8]), Estimated Mass=0.02338 Order=([3 6 2 1 4 5 8 7 9]), Estimated Mass=0.02231 Order=([3 6 2 4 7 8 1 5 9]), Estimated Mass=0.01888 Order=([1 3 4 5 9 2 6 8 7]), Estimated Mass=0.0184 Subject 13 There are 214 search orders sampled in the posterior. Order=([6 1 7 8 9 5 2 4 3]), Estimated Mass=0.05278 Order=([6 1 7 8 9 5 3 2 4]), Estimated Mass=0.04959 Order=([5 8 1 3 6 9 7 4 2]), Estimated Mass=0.04332 Order=([6 7 1 2 8 9 3 5 4]), Estimated Mass=0.03655 Order=([6 1 7 9 8 5 3 4 2]), Estimated Mass=0.03116 Order=([6 7 2 1 8 9 3 5 4]), Estimated Mass=0.02788 Order=([6 1 7 8 9 5 2 3 4]), Estimated Mass=0.02453 Order=([1 8 3 5 6 9 7 4 2]), Estimated Mass=0.02317 Order=([6 7 1 2 8 9 5 3 4]), Estimated Mass=0.02253 Order=([8 5 3 1 6 9 2 4 7]), Estimated Mass=0.02165
The returned orders are not at all similar to the results from JAGS (shown in the book on p. 233). However, cue 2 is searched before cue 6 for Subject 12, and vice versa for Subject 13, which agrees with the book.
$$ t_{iq} =
\begin{cases}
\text{TTB}_{s_i}\,(\mathbf a_q,\mathbf b_q) & \text{if $z_i = 1$} \\
\text{WADD}\,(\mathbf a_q,\mathbf b_q) & \text{if $z_i = 0$} \\
\end{cases} $$
$$ y_{iq} \sim
\begin{cases}
\text{Bernoulli}(\gamma) & \text{if $t_{iq} = a$} \\
\text{Bernoulli}(1- \gamma) & \text{if $t_{iq} = b$} \\
\text{Bernoulli}(0.5) & \text{otherwise}
\end{cases} $$
# Model 4: mixture of WADD and TTB with a latent subject-specific search order.
with pm.Model() as model4:
    # Base rate of TTB users; Beta(1, 1) is Uniform(0, 1).
    phi = pm.Beta('phi', alpha=1, beta=1, testval=.01)
    # zi[i] = 1: subject i uses TTB; zi[i] = 0: WADD.
    zi = pm.Bernoulli('zi', p=phi, shape=ns,
                      testval=np.asarray([1,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]))
    # Repeat each subject's indicator over its nq questions -> (ns, nq).
    zi_ = tt.reshape(tt.repeat(zi, nq), (ns, nq))
    gamma = pm.Uniform('gamma', lower=.5, upper=1)
    gammat = tt.stack([1-gamma, .5, gamma])
    # Latent per-(subject, cue) scores whose per-subject argsort supplies the
    # TTB exponents (same construction as model 3).
    v1 = pm.HalfNormal('v1', sd=1, shape=ns*nc)
    s1 = pm.Deterministic('s1', tt.argsort(v1.reshape((ns, 1, nc)), axis=2))
    smat2 = tt.tile(s1, (1, nq, 1)) # s[1:nc] <- rank(v[1:nc])
    # TTB Model For Each Question
    ttmp = tt.sum(qccmat*tt.power(2, smat2), axis=2)
    tmat = -1*(-ttmp > 0) + (ttmp > 0) + 1
    # Per subject, take the TTB (latent order) or WADD (numpy wmat) prediction.
    t2 = tt.switch(tt.eq(zi_, 1), tmat, wmat)
    yiq = pm.Bernoulli('yiq', p=gammat[t2], observed=y)
    # Gradient-free sampler (argsort makes the likelihood piecewise constant
    # in v1); the draw count is passed as an int rather than the float 1e5.
    trace4 = pm.sample(100000, step=pm.Metropolis())
burnin = 50000
pm.traceplot(trace4[burnin:], varnames=['phi', 'gamma']);
100%|██████████| 100500/100500.0 [02:38<00:00, 632.91it/s]
# Posterior predictive check for model 4, using only post-burn-in draws.
ppc = pm.sample_ppc(trace4[burnin:], samples=100, model=model4)
yiqpred = np.asarray(ppc['yiq'])
fig = plt.figure(figsize=(16, 8))

# np.indices yields (subject, question) index grids of shape (ns, nq);
# ravel them row-major and shift to 1-based labels.
subject_grid, question_grid = np.indices((ns, nq)) + 1
x1 = subject_grid.ravel()
y1 = question_grid.ravel()
hits = y.flatten() == 1

mean_pred = np.mean(yiqpred, axis=0)
plt.scatter(y1, x1, s=mean_pred * 200, c='w')      # predicted
plt.scatter(y1[hits], x1[hits], marker='x', c='r')  # observed 1s
divider_x = np.ones(100) * 24.5
plt.plot(divider_x, np.linspace(0, 21, 100), '--', lw=1.5, c='k')
plt.axis([0, 31, 0, 21])
plt.show()
100%|██████████| 100/100 [00:01<00:00, 89.13it/s]