using DifferentiableStateSpaceModels, LinearAlgebra, Turing, Zygote
using DifferentiableStateSpaceModels.Examples
using Turing: @addlogprob!
import Random
Turing.setadbackend(:zygote)
# Build the RBC model from the bundled example module, then compute its first- and
# second-order perturbation solutions at the pseudo-true parameters.
model_rbc = @include_example_module(Examples.rbc_observables)
# Generate artificial data for estimation
p_f = (ρ=0.2, δ=0.02, σ=0.01, Ω_1=0.01) # Fixed parameters
p_d = (α=0.5, β=0.95) # Pseudo-true values
sol = generate_perturbation(model_rbc, p_d, p_f, Val(1))
sol_second = generate_perturbation(model_rbc, p_d, p_f, Val(2))
# Seed the global RNG so the simulated shocks, and therefore the fake data, are
# reproducible from run to run (for a given Julia version).
Random.seed!(1234)
T = 20
# T shock vectors (each of length n_ϵ) over the time span (0, T); the simulated
# observable path `z` then has T + 1 entries.
ϵ = [randn(model_rbc.n_ϵ) for _ in 1:T]
x0 = zeros(model_rbc.n_x)
# The same shock path drives both the first-order (LTI) and second-order (QTI) simulations.
fake_z = solve(sol, x0, (0, T), DifferentiableStateSpaceModels.LTI(); noise = ϵ).z
fake_z_second = solve(sol_second, x0, (0, T), DifferentiableStateSpaceModels.QTI(); noise = ϵ).z
┌ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0] └ @ Base loading.jl:1342
21-element Vector{Vector{Float64}}: [7.824904812740593e-5, 0.0] [0.009388280821124171, -7.824904812740593e-5] [0.004141003230806073, 0.08564547473108383] [0.003113377412500409, 0.043761814038505646] [0.007700917546593067, 0.03114748946994769] [0.0028219400485847866, 0.07243491327720691] [0.014351652682013355, 0.030627394848069485] [0.014965386945969147, 0.1336674064514162] [0.014501839251820175, 0.1470204417729682] [0.012121007622016115, 0.14375484164268565] [0.006158082144511833, 0.12158897809077801] [0.0022856243818403987, 0.06503713781067733] [-0.005567886073177921, 0.02513421406357805] [-0.007159495107005886, -0.05012136636762045] [-0.0009141646711710215, -0.07044723363875496] [0.006966287246748618, -0.014495428593979284] [0.013428142860868263, 0.06225784779844589] [0.013777970823039514, 0.1275144960461263] [0.009452327892967818, 0.13562318681100563] [0.009420130597109135, 0.09641143723837649] [0.0071400525559970375, 0.09316115727178581]
## Estimation example: first-order, marginal likelihood approach
# Turing model definition
@model function rbc_kalman(z, m, p_f, cache)
    # Priors on the two estimated deep parameters.
    α ~ Uniform(0.2, 0.8)
    β ~ Uniform(0.5, 0.99)
    # First-order perturbation solution at the proposed (α, β); `cache` is reused
    # across evaluations.
    perturbation = generate_perturbation(m, (; α, β), p_f, Val(1); cache)
    # A failed solve contributes -Inf to the log density, so the proposal is rejected.
    if perturbation.retcode != :Success
        @addlogprob! -Inf
        return
    end
    # Marginal likelihood of `z`: no shock path is passed to `solve`, and the
    # initial state is the solution's `x_ergodic` field.
    tspan = (0, length(z))
    @addlogprob! solve(perturbation, perturbation.x_ergodic, tspan; observables = z).logpdf
end
# NUTS settings: `n_adapts` adaptation steps, target acceptance rate `δ` (unrelated to
# the depreciation parameter `δ` inside `p_f`), and `n_samples` retained draws.
n_adapts = 100
δ = 0.65
n_samples = 1000
# One first-order solver cache, built from the pseudo-true `p_d`, shared by every
# likelihood evaluation of the model.
c = SolverCache(model_rbc, Val(1), p_d)
turing_model = rbc_kalman(fake_z, model_rbc, p_f, c)
chain = sample(turing_model, NUTS(n_adapts, δ), n_samples; progress = true)
┌ Info: Found initial step size
│ ϵ = 0.2
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\wupei\.julia\packages\AdvancedHMC\HQHnm\src\hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\wupei\.julia\packages\AdvancedHMC\HQHnm\src\hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\wupei\.julia\packages\AdvancedHMC\HQHnm\src\hamiltonian.jl:47
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:42
Chains MCMC chain (1000×14×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 121.72 seconds Compute duration = 121.72 seconds parameters = α, β internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4584 0.0338 0.0011 0.0021 237.0054 0.9998 ⋯ β 0.9516 0.0144 0.0005 0.0009 212.9579 1.0014 ⋯ 1 column omitted Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.3922 0.4357 0.4582 0.4806 0.5219 β 0.9212 0.9424 0.9530 0.9614 0.9771
## Estimation example: first-order, joint likelihood approach
# Turing model definition
@model function rbc_joint(z, m, p_f, cache, x0 = zeros(m.n_x))
    # Priors on the two estimated deep parameters.
    α ~ Uniform(0.2, 0.8)
    β ~ Uniform(0.5, 0.99)
    p_d = (α = α, β = β)
    T = length(z)
    # One standard-normal draw per shock per period. The flat vector needs
    # m.n_ϵ * T entries for the slicing below to stay in bounds; sizing it by T
    # alone only works when the model has a single shock (m.n_ϵ == 1).
    ϵ_draw ~ MvNormal(m.n_ϵ * T, 1.0)
    # Split the flat draw into T consecutive shock vectors of length m.n_ϵ.
    ϵ = map(i -> ϵ_draw[((i - 1) * m.n_ϵ + 1):(i * m.n_ϵ)], 1:T)
    sol = generate_perturbation(m, p_d, p_f, Val(1); cache)
    # A failed solve contributes -Inf to the log density, so the proposal is rejected.
    if sol.retcode != :Success
        @addlogprob! -Inf
        return
    end
    # NOTE(review): the fake data were simulated with T shocks over (0, T), which gave
    # T + 1 observations. Here length(z) shocks are drawn over (0, length(z)), and
    # the sampled posterior for the last shock stays close to its N(0, 1) prior.
    # That suggests the final shock may not affect any observation; confirm the
    # noise/observables length convention expected by `solve`.
    @addlogprob! solve(sol, x0, (0, T); noise = ϵ, observables = z).logpdf
end
# NUTS settings. `δ` is the target acceptance rate (not the depreciation rate in `p_f`).
# A lower `max_depth` runs faster but yields more autocorrelated draws; the cost per
# iteration grows roughly like 2^max_depth.
n_adapts = 100
δ = 0.65
max_depth = 5
n_samples = 1000
# First-order solver cache shared by every likelihood evaluation.
c = SolverCache(model_rbc, Val(1), p_d)
turing_model = rbc_joint(fake_z, model_rbc, p_f, c)
chain = sample(
    turing_model, NUTS(n_adapts, δ; max_depth = max_depth), n_samples; progress = true
)
┌ Info: Found initial step size
│ ϵ = 0.21250000000000002
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:02:26
Chains MCMC chain (1000×35×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 182.9 seconds Compute duration = 182.9 seconds parameters = α, β, ϵ_draw[1], ϵ_draw[2], ϵ_draw[3], ϵ_draw[4], ϵ_draw[5], ϵ_draw[6], ϵ_draw[7], ϵ_draw[8], ϵ_draw[9], ϵ_draw[10], ϵ_draw[11], ϵ_draw[12], ϵ_draw[13], ϵ_draw[14], ϵ_draw[15], ϵ_draw[16], ϵ_draw[17], ϵ_draw[18], ϵ_draw[19], ϵ_draw[20], ϵ_draw[21] internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4581 0.0303 0.0010 0.0029 88.5694 1.0073 ⋯ β 0.9530 0.0129 0.0004 0.0010 137.1778 0.9998 ⋯ ϵ_draw[1] -0.1112 0.2148 0.0068 0.0135 254.0827 1.0018 ⋯ ϵ_draw[2] -1.6127 0.3801 0.0120 0.0384 78.9215 1.0020 ⋯ ϵ_draw[3] 0.9781 0.3268 0.0103 0.0275 118.5946 0.9990 ⋯ ϵ_draw[4] 0.0767 0.3007 0.0095 0.0215 191.5072 0.9991 ⋯ ϵ_draw[5] -0.8353 0.3071 0.0097 0.0237 137.3091 1.0009 ⋯ ϵ_draw[6] 0.7644 0.3130 0.0099 0.0212 201.3213 1.0058 ⋯ ϵ_draw[7] -2.1753 0.4250 0.0134 0.0453 55.7146 1.0245 ⋯ ϵ_draw[8] -0.1410 0.3344 0.0106 0.0282 138.6009 1.0063 ⋯ ϵ_draw[9] 0.0578 0.3263 0.0103 0.0282 123.2021 1.0167 ⋯ ϵ_draw[10] 0.3613 0.2978 0.0094 0.0225 150.6162 1.0193 ⋯ ϵ_draw[11] 0.9990 0.3364 0.0106 0.0292 109.8132 1.0318 ⋯ ϵ_draw[12] 0.7090 0.3410 0.0108 0.0311 110.5820 1.0041 ⋯ ϵ_draw[13] 1.3763 0.3569 0.0113 0.0318 121.6980 0.9990 ⋯ ϵ_draw[14] 0.1758 0.3145 0.0099 0.0258 97.5849 1.0293 ⋯ ϵ_draw[15] -1.2367 0.3541 0.0112 0.0360 45.1033 1.0469 ⋯ ϵ_draw[16] -1.4202 0.3882 0.0123 0.0333 132.8966 1.0016 ⋯ ϵ_draw[17] -1.1673 0.3560 0.0113 0.0328 98.0627 1.0097 ⋯ ϵ_draw[18] -0.0066 0.2863 0.0091 0.0192 212.5723 0.9990 ⋯ ϵ_draw[19] 0.6946 0.2987 0.0094 0.0234 134.0651 1.0145 ⋯ ϵ_draw[20] -0.0934 0.3022 0.0096 0.0224 178.9008 1.0169 ⋯ 
ϵ_draw[21] 0.2455 0.8400 0.0266 0.1138 32.7431 1.0297 ⋯ 1 column omitted Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.3986 0.4381 0.4565 0.4796 0.5132 β 0.9272 0.9441 0.9533 0.9627 0.9769 ϵ_draw[1] -0.5821 -0.2391 -0.1085 0.0325 0.2853 ϵ_draw[2] -2.3621 -1.8703 -1.5998 -1.3407 -0.9117 ϵ_draw[3] 0.3892 0.7321 0.9788 1.1906 1.6516 ϵ_draw[4] -0.5363 -0.1241 0.0813 0.2777 0.6609 ϵ_draw[5] -1.4308 -1.0292 -0.8373 -0.6355 -0.2063 ϵ_draw[6] 0.1613 0.5525 0.7625 0.9680 1.3617 ϵ_draw[7] -3.0713 -2.4564 -2.1515 -1.8744 -1.4148 ϵ_draw[8] -0.8838 -0.3437 -0.1170 0.0930 0.4603 ϵ_draw[9] -0.5789 -0.1598 0.0502 0.2642 0.7354 ϵ_draw[10] -0.2095 0.1575 0.3489 0.5470 0.9839 ϵ_draw[11] 0.3521 0.7656 0.9840 1.2140 1.6836 ϵ_draw[12] 0.0783 0.4813 0.6932 0.9084 1.4197 ϵ_draw[13] 0.6989 1.1431 1.3750 1.6057 2.1030 ϵ_draw[14] -0.4388 -0.0387 0.1694 0.3707 0.8161 ϵ_draw[15] -1.9408 -1.4770 -1.2262 -1.0006 -0.5732 ϵ_draw[16] -2.2357 -1.6557 -1.4001 -1.1456 -0.7469 ϵ_draw[17] -1.9231 -1.3939 -1.1531 -0.9090 -0.5508 ϵ_draw[18] -0.5756 -0.1715 -0.0108 0.1741 0.5813 ϵ_draw[19] 0.1203 0.4916 0.6926 0.8879 1.3186 ϵ_draw[20] -0.7047 -0.2895 -0.0822 0.1252 0.4489 ϵ_draw[21] -1.4493 -0.2894 0.2262 0.8156 1.9122
## Estimation example: second-order, joint likelihood approach
# Turing model definition
@model function rbc_second(z, m, p_f, cache, x0 = zeros(m.n_x))
    # Priors on the two estimated deep parameters.
    α ~ Uniform(0.2, 0.8)
    β ~ Uniform(0.5, 0.99)
    p_d = (α = α, β = β)
    T = length(z)
    # One standard-normal draw per shock per period. The flat vector needs
    # m.n_ϵ * T entries for the slicing below to stay in bounds; sizing it by T
    # alone only works when the model has a single shock (m.n_ϵ == 1).
    ϵ_draw ~ MvNormal(m.n_ϵ * T, 1.0)
    # Split the flat draw into T consecutive shock vectors of length m.n_ϵ.
    ϵ = map(i -> ϵ_draw[((i - 1) * m.n_ϵ + 1):(i * m.n_ϵ)], 1:T)
    # Second-order perturbation solution at the proposed (α, β).
    sol = generate_perturbation(m, p_d, p_f, Val(2); cache)
    # A failed solve contributes -Inf to the log density, so the proposal is rejected.
    if sol.retcode != :Success
        @addlogprob! -Inf
        return
    end
    # NOTE(review): the fake data were simulated with T shocks over (0, T), which gave
    # T + 1 observations. Here length(z) shocks are drawn over (0, length(z)), and
    # the sampled posterior for the last shock stays close to its N(0, 1) prior.
    # That suggests the final shock may not affect any observation; confirm the
    # noise/observables length convention expected by `solve`.
    @addlogprob! solve(sol, x0, (0, T); noise = ϵ, observables = z).logpdf
end
# NUTS settings for the second-order model: `n_adapts` adaptation steps, target
# acceptance rate `δ`, maximum tree depth `max_depth`, and `n_samples` retained draws.
n_adapts = 100
δ = 0.65
max_depth = 5
n_samples = 1000
# Second-order solver cache shared by every likelihood evaluation; the model is fit
# to the data simulated from the second-order solution.
c = SolverCache(model_rbc, Val(2), p_d)
turing_model = rbc_second(fake_z_second, model_rbc, p_f, c)
chain = sample(
    turing_model, NUTS(n_adapts, δ; max_depth = max_depth), n_samples; progress = true
)
┌ Info: Found initial step size
│ ϵ = 0.3361328125
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:04:48
Chains MCMC chain (1000×35×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 314.0 seconds Compute duration = 314.0 seconds parameters = α, β, ϵ_draw[1], ϵ_draw[2], ϵ_draw[3], ϵ_draw[4], ϵ_draw[5], ϵ_draw[6], ϵ_draw[7], ϵ_draw[8], ϵ_draw[9], ϵ_draw[10], ϵ_draw[11], ϵ_draw[12], ϵ_draw[13], ϵ_draw[14], ϵ_draw[15], ϵ_draw[16], ϵ_draw[17], ϵ_draw[18], ϵ_draw[19], ϵ_draw[20], ϵ_draw[21] internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4464 0.0365 0.0012 0.0048 21.0310 1.0559 ⋯ β 0.9576 0.0163 0.0005 0.0020 23.2836 1.0623 ⋯ ϵ_draw[1] -0.1132 0.2236 0.0071 0.0129 322.5306 1.0038 ⋯ ϵ_draw[2] -1.6483 0.3299 0.0104 0.0257 135.1192 1.0057 ⋯ ϵ_draw[3] 0.9946 0.3257 0.0103 0.0228 162.2103 0.9990 ⋯ ϵ_draw[4] 0.0842 0.3149 0.0100 0.0204 195.4256 1.0053 ⋯ ϵ_draw[5] -0.8197 0.3311 0.0105 0.0244 169.4104 0.9991 ⋯ ϵ_draw[6] 0.7392 0.3197 0.0101 0.0235 180.6771 0.9997 ⋯ ϵ_draw[7] -2.1800 0.3829 0.0121 0.0353 94.6605 0.9994 ⋯ ϵ_draw[8] -0.1633 0.3227 0.0102 0.0218 209.1768 0.9991 ⋯ ϵ_draw[9] 0.0988 0.3416 0.0108 0.0276 148.3484 0.9994 ⋯ ϵ_draw[10] 0.3421 0.3350 0.0106 0.0224 192.7470 0.9995 ⋯ ϵ_draw[11] 1.0727 0.3705 0.0117 0.0339 124.0563 0.9991 ⋯ ϵ_draw[12] 0.6768 0.3368 0.0106 0.0254 166.8420 1.0030 ⋯ ϵ_draw[13] 1.4199 0.3660 0.0116 0.0297 143.1292 0.9990 ⋯ ϵ_draw[14] 0.1841 0.3071 0.0097 0.0211 189.1216 1.0200 ⋯ ϵ_draw[15] -1.2584 0.3282 0.0104 0.0214 207.2431 1.0041 ⋯ ϵ_draw[16] -1.4749 0.3792 0.0120 0.0325 117.7249 1.0022 ⋯ ϵ_draw[17] -1.1119 0.3349 0.0106 0.0248 155.6362 1.0149 ⋯ ϵ_draw[18] -0.0729 0.3063 0.0097 0.0241 157.5904 1.0055 ⋯ ϵ_draw[19] 0.7628 0.3302 0.0104 0.0266 153.4378 0.9990 ⋯ ϵ_draw[20] -0.1290 0.3373 0.0107 0.0244 181.6521 0.9993 ⋯ 
ϵ_draw[21] 0.1036 1.0803 0.0342 0.1573 26.9836 0.9996 ⋯ 1 column omitted Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.3680 0.4224 0.4504 0.4719 0.5092 β 0.9292 0.9452 0.9559 0.9681 0.9890 ϵ_draw[1] -0.5377 -0.2575 -0.1143 0.0366 0.3093 ϵ_draw[2] -2.3022 -1.8760 -1.6345 -1.4229 -1.0078 ϵ_draw[3] 0.3920 0.7796 0.9840 1.2017 1.6058 ϵ_draw[4] -0.4895 -0.1231 0.0836 0.2986 0.7042 ϵ_draw[5] -1.4837 -1.0168 -0.8432 -0.6017 -0.1574 ϵ_draw[6] 0.1416 0.5062 0.7398 0.9489 1.3727 ϵ_draw[7] -2.9379 -2.4387 -2.1843 -1.9008 -1.4795 ϵ_draw[8] -0.8892 -0.3809 -0.1370 0.0618 0.4209 ϵ_draw[9] -0.5651 -0.1390 0.1001 0.3227 0.7729 ϵ_draw[10] -0.2854 0.1179 0.3234 0.5570 1.0664 ϵ_draw[11] 0.4102 0.8309 1.0583 1.3086 1.8208 ϵ_draw[12] 0.0687 0.4465 0.6619 0.8864 1.3733 ϵ_draw[13] 0.7310 1.1664 1.4166 1.6612 2.1622 ϵ_draw[14] -0.4114 -0.0211 0.2044 0.3867 0.7812 ϵ_draw[15] -1.9786 -1.4656 -1.2507 -1.0293 -0.6627 ϵ_draw[16] -2.2371 -1.7418 -1.4459 -1.2001 -0.8466 ϵ_draw[17] -1.7675 -1.3303 -1.0980 -0.8747 -0.4988 ϵ_draw[18] -0.6794 -0.2647 -0.0651 0.1333 0.5072 ϵ_draw[19] 0.1462 0.5362 0.7524 0.9908 1.3992 ϵ_draw[20] -0.7910 -0.3529 -0.1329 0.0879 0.5805 ϵ_draw[21] -1.9288 -0.7344 0.1410 0.9009 2.3125