import pymc as pm
import numpy as np
# Parameters of the two simulated subjects
beta_params = [0.9, 0.8]
lambda_params = [1.5, 0.6]
delta_params = [1.5, 0.6]
# Performance function: probability of a correct response given input x
def performance(x, beta_param, lambda_param, delta_param):
    p = beta_param * (1 - np.exp(-(x - delta_param) / lambda_param))
    return np.random.binomial(1, p)
# Generate input data and simulate responses for both subjects
x_init = np.random.choice([2, 3, 4, 5, 6], 50)
correct1 = performance(x_init, beta_params[0], lambda_params[0], delta_params[0])
correct2 = performance(x_init, beta_params[1], lambda_params[1], delta_params[1])
codes = np.repeat([0, 1], 50)
correct_init = np.concatenate([correct1, correct2])
x = np.concatenate([x_init, x_init])
# Observed ("implemented") optimal x values, one per participant
x_optim = [4.2, 3.1]
codes_optim = [0, 1]
# print shapes
print(f"Shape of x: {x.shape}")
print(f"Shape of codes: {codes.shape}")
print(f"Shape of correct_init: {correct_init.shape}")
Shape of x: (100,)
Shape of codes: (100,)
Shape of correct_init: (100,)
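As a quick sanity check (not part of the original snippet), the empirical accuracy of each simulated participant can be compared against the generating parameters; this only reuses the `codes` and `correct_init` arrays defined above.

# Sanity check: empirical accuracy per simulated participant
for participant in [0, 1]:
    mask = codes == participant
    print(f"Participant {participant}: empirical accuracy = {correct_init[mask].mean():.2f}")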
import pytensor as pt

def index_fun(argmax_tensor):
    '''Map an argmax index back to the corresponding value on the input grid.'''
    x = pt.shared(np.linspace(0, 15, 1000))
    return x[argmax_tensor]

def argmax_fun(beta_param, lambda_param, delta_param):
    '''Return, per participant, the x that maximizes cost-weighted performance.'''
    x_init = np.linspace(0, 15, 1000)
    x = np.tile(x_init, (2, 1)).swapaxes(0, 1)  # shape (1000, 2): one column per participant
    x = pt.shared(x)  # convert to a pytensor shared variable
    COST = np.array([10])  # cost per unit of x
    probability_simul = beta_param * (1 - pt.tensor.exp(-(x - delta_param) / lambda_param))
    # Expected reward: success probability times a payoff that shrinks with cost
    probability_simul_low = probability_simul * (150 - COST[0] * x)
    maximum_low = pt.tensor.argmax(probability_simul_low, axis=0)
    return index_fun(maximum_low)
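Since the symbolic grid search is hard to inspect directly, a plain-NumPy replica (a verification sketch added here, not part of the original post) shows what `argmax_fun` should return when fed the true generating parameters; the values should land near the `x_optim` pair above.

# NumPy cross-check of the symbolic argmax, using the true parameters
grid = np.linspace(0, 15, 1000)
for i in range(2):
    p = beta_params[i] * (1 - np.exp(-(grid - delta_params[i]) / lambda_params[i]))
    reward = p * (150 - 10 * grid)  # same cost-weighted payoff as in argmax_fun
    print(f"Participant {i}: optimal x ~= {grid[np.argmax(reward)]:.2f}")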
coords = {
    "participants": ["0", "1"],
    "participant_idxs": codes,
    "obs_id": np.arange(len(codes)),
    "obs_id_optim": np.arange(len(codes_optim)),
}
with pm.Model(coords=coords) as model:
    participant_idx = pm.Data("participant_idx", codes, dims="obs_id")
    participant_idx_optim = pm.Data("participant_idx_optim", codes_optim, dims="obs_id_optim")
    # Three parameters of the performance function, one per participant
    lambda_param = pm.TruncatedNormal("lambda_param", mu=0, sigma=1, dims="participants", lower=0)
    beta_param = pm.TruncatedNormal("beta_param", mu=0, sigma=1, dims="participants", lower=0, upper=1)
    delta_param = pm.TruncatedNormal("delta_param", mu=0, sigma=1, dims="participants", lower=0)
    # Deviation parameter (a single sigma shared across participants)
    sigma = pm.HalfNormal("sigma_plan", sigma=1)
    # Performance function
    probability = beta_param[participant_idx] * (1 - pm.math.exp(-(x - delta_param[participant_idx]) / lambda_param[participant_idx]))
    # Bernoulli likelihood for the observed responses
    y = pm.Bernoulli('y', p=probability, observed=correct_init)
    # Optimal x per participant, derived symbolically from the current parameters
    sampled_optim = argmax_fun(beta_param, lambda_param, delta_param)
    pm.Deterministic('sampled_optim', sampled_optim, dims="participants")
    # Observed implemented optima, modeled as noisy versions of the derived optima
    implemented_optim = pm.Normal('implemented_optim', mu=sampled_optim[participant_idx_optim], sigma=sigma, observed=x_optim)
    trace = pm.sample(1000, tune=1000, cores=4, chains=4, init="adapt_diag", return_inferencedata=True)
/Users/nwitzi01/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/data.py:433: UserWarning: The `mutable` kwarg was not specified. Before v4.1.0 it defaulted to `pm.Data(mutable=True)`, which is equivalent to using `pm.MutableData()`. In v4.1.0 the default changed to `pm.Data(mutable=False)`, equivalent to `pm.ConstantData`. Use `pm.ConstantData`/`pm.MutableData` or pass `pm.Data(..., mutable=False/True)` to avoid this warning.
  warnings.warn(
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [lambda_param, beta_param, delta_param, sigma_plan]
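The `mutable` warning is harmless here: the index arrays never change between runs, so (assuming PyMC >= 4.1, which the message indicates) they could be declared as constant data inside the model block to silence it.

    # Inside the model block, constant data avoids the `mutable` warning
    participant_idx = pm.ConstantData("participant_idx", codes, dims="obs_id")
    participant_idx_optim = pm.ConstantData("participant_idx_optim", codes_optim, dims="obs_id_optim")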
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 135 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
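Following the sampler's own suggestions, a first remedy (a sketch, not a guaranteed fix) is to raise `target_accept` and `max_treedepth` and tune longer; note that the grid-based argmax makes `sampled_optim` piecewise constant in the parameters, which can itself produce a difficult posterior geometry.

# Possible re-run with the settings the warnings suggest
with model:
    trace = pm.sample(
        1000,
        tune=2000,              # longer warm-up for step-size adaptation
        cores=4,
        chains=4,
        target_accept=0.95,     # smaller steps, fewer max-treedepth hits
        max_treedepth=12,       # allow longer NUTS trajectories
        return_inferencedata=True,
    )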
pm.summary(trace)
|                  | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|------------------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| lambda_param[0]  | 1.654 | 0.451 | 0.896  | 2.525   | 0.134     | 0.098   | 12.0     | 97.0     | 1.24  |
| lambda_param[1]  | 0.864 | 0.296 | 0.303  | 1.374   | 0.065     | 0.047   | 22.0     | 179.0    | 1.13  |
| beta_param[0]    | 0.866 | 0.079 | 0.742  | 0.998   | 0.019     | 0.014   | 17.0     | 38.0     | 1.16  |
| beta_param[1]    | 0.765 | 0.049 | 0.657  | 0.853   | 0.009     | 0.006   | 22.0     | 190.0    | 1.38  |
| delta_param[0]   | 0.764 | 0.436 | 0.051  | 1.478   | 0.138     | 0.101   | 11.0     | 33.0     | 1.27  |
| delta_param[1]   | 0.555 | 0.337 | 0.011  | 1.143   | 0.060     | 0.043   | 27.0     | 17.0     | 1.15  |
| sigma_plan       | 0.577 | 0.429 | 0.021  | 1.366   | 0.084     | 0.060   | 15.0     | 39.0     | 1.20  |
| sampled_optim[0] | 4.066 | 0.384 | 3.213  | 4.700   | 0.075     | 0.054   | 33.0     | 65.0     | 1.10  |
| sampled_optim[1] | 2.842 | 0.476 | 2.027  | 3.769   | 0.097     | 0.071   | 23.0     | 89.0     | 1.14  |
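Given the r_hat values above 1.1 and the very low bulk ESS, it is worth inspecting the chains visually before trusting any estimates; a minimal ArviZ sketch (assuming `arviz` is installed alongside PyMC):

import arviz as az

# Trace plots show whether the four chains explore the same region
az.plot_trace(trace, var_names=["lambda_param", "delta_param", "sigma_plan"])
# Rank plots make between-chain disagreement easy to spot
az.plot_rank(trace, var_names=["sampled_optim"])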