using DifferentiableStateSpaceModels, DifferenceEquations, LinearAlgebra, Turing, Zygote
using DifferentiableStateSpaceModels.Examples
using Turing: @addlogprob!
Turing.setadbackend(:zygote)
# Create models from modules and then solve
model_rbc = @include_example_module(Examples.rbc_observables)
# Generate artificial data for estimation
p_f = (ρ = 0.2, δ = 0.02, σ = 0.01, Ω_1 = 0.01) # Fixed parameters
p_d = (α = 0.5, β = 0.95) # Pseudo-true values used to simulate the data; the estimation below should recover these
sol = generate_perturbation(model_rbc, p_d, p_f, Val(1)) # Solution to the first-order RBC
sol_second = generate_perturbation(model_rbc, p_d, p_f, Val(2)) # Solution to the second-order RBC
T = 20 # number of simulated periods (also the length of the observable series)
ϵ = [randn(model_rbc.n_ϵ) for _ in 1:T] # one n_ϵ-vector of standard-normal shocks per period
x0 = zeros(model_rbc.n_x) # initial state; presumably deviations from steady state, so zero = start at the steady state — confirm
# Create a first-order problem setting.
# The dssm_* callbacks supply the state evolution, shock volatility loading, and
# observation equations implied by the perturbation solution `sol`.
problem_first_order = StateSpaceProblem(
DifferentiableStateSpaceModels.dssm_evolution,
DifferentiableStateSpaceModels.dssm_volatility,
DifferentiableStateSpaceModels.dssm_observation,
x0, # initial state
(0, T), # time span: periods 0 through T
sol, # first-order perturbation solution carrying the model matrices
noise = DefinedNoise(ϵ), # use the fixed shock draws from above rather than fresh noise
)
# Generate fake data for first-order estimation exercises
# (drop the first element so the observables cover periods 1:T)
fake_z = DifferenceEquations.solve(problem_first_order, NoiseConditionalFilter()).z[2:end]
# Create a second-order problem setting.
# Same construction as the first-order problem, but with the second-order solution;
# the state is stacked ([x0; x0]) to match the second-order state representation.
problem_second_order = StateSpaceProblem(
DifferentiableStateSpaceModels.dssm_evolution,
DifferentiableStateSpaceModels.dssm_volatility,
DifferentiableStateSpaceModels.dssm_observation,
[x0; x0], # stacked initial state for the second-order representation
(0, T), # time span: periods 0 through T
sol_second, # second-order perturbation solution
noise = DefinedNoise(ϵ), # reuse the same shock draws as the first-order simulation
)
# Generate fake data for second-order estimation exercises
fake_z_second = DifferenceEquations.solve(problem_second_order, NoiseConditionalFilter()).z[2:end]
20-element Vector{Vector{Float64}}: [0.010334871038349262, -7.824904812740593e-5] [0.009423566208750341, 0.0943670667134858] [0.015008764884793328, 0.09303935790924596] [0.015016700100943674, 0.1443767660379526] [0.02113467304000077, 0.14829744045734816] [0.013036180623260741, 0.20494961851357657] [0.020128017797618636, 0.1346337678757667] [0.017169363491181408, 0.19465230103200418] [0.016957726243991545, 0.17190138588597445] [0.006459137337200757, 0.16824052733838765] [0.011332069203908293, 0.07134164588489617] [0.01427844664817287, 0.1088856414442139] [0.0036495297240985917, 0.13883510628593423] [-0.0009176958065385961, 0.04326796775724953] [0.002087202549504044, -0.005980498232478319] [0.008575421729870135, 0.01796423681254636] [0.003777787105670909, 0.07950304294265159] [-0.008164894519121842, 0.03995580512836877] [-0.005196677187614779, -0.0728703298489967] [-0.005802045134383482, -0.05409787878535121]
## Estimation example: first-order, marginal likelihood approach
# Turing model definition
# Turing model: marginal-likelihood estimation of the first-order RBC.
# The first-order perturbation solution is linear-Gaussian, so the shocks are
# integrated out exactly with a Kalman filter instead of being sampled.
@model function rbc_kalman(z, m, p_f, cache)
    # Priors over the structural parameters being estimated.
    α ~ Uniform(0.2, 0.8)
    β ~ Uniform(0.5, 0.99)
    proposed = (; α, β)
    # Solve the first-order perturbation at the proposed parameters, reusing the cache.
    sol = generate_perturbation(m, proposed, p_f, Val(1); cache)
    if sol.retcode != :Success
        # Solver failure: assign zero posterior mass and bail out.
        @addlogprob! -Inf
        return
    end
    # Linear-Gaussian state space built from the solution matrices, started
    # from the ergodic distribution, with observation noise covariance sol.D.
    lss = LinearStateSpaceProblem(sol.A, sol.B, sol.C, sol.x_ergodic, (0, T);
        obs_noise = sol.D, observables = z)
    # Fold the Kalman-filter marginal log-likelihood into the joint density.
    @addlogprob! DifferenceEquations.solve(lss, KalmanFilter(); vectype = Zygote.Buffer).loglikelihood
end
c = SolverCache(model_rbc, Val(1), p_d) # reusable cache for repeated first-order solves inside the sampler
turing_model = rbc_kalman(fake_z, model_rbc, p_f, c)
n_samples = 1000 # post-adaptation NUTS draws
n_adapts = 100 # NUTS warmup/adaptation iterations
δ = 0.65 # NUTS target acceptance rate
chain = sample(turing_model, NUTS(n_adapts, δ), n_samples; progress = true)
┌ Info: Found initial step size
│ ϵ = 0.4
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC C:\Users\wupei\.julia\packages\AdvancedHMC\HQHnm\src\hamiltonian.jl:47
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:35
Chains MCMC chain (1000×14×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 109.88 seconds Compute duration = 109.88 seconds parameters = α, β internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4836 0.0313 0.0010 0.0017 264.0063 1.0024 ⋯ β 0.9514 0.0124 0.0004 0.0008 217.9021 1.0011 ⋯ 1 column omitted Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.4218 0.4626 0.4843 0.5063 0.5420 β 0.9273 0.9429 0.9511 0.9603 0.9763
# Turing model definition
# Turing model: joint-likelihood estimation, sampling the structural parameters
# AND the full shock path. Works for any perturbation order carried by the cache
# (first- and second-order in this file); `x0` must match that order's state size.
@model function rbc_joint(z, m, p_f, cache::DifferentiableStateSpaceModels.AbstractSolverCache{Order}, x0) where {Order}
    # Priors over the structural parameters being estimated.
    α ~ Uniform(0.2, 0.8)
    β ~ Uniform(0.5, 0.99)
    p_d = (; α, β)
    T = length(z)
    # Standard-normal prior over the stacked shock vector. The slicing below reads
    # m.n_ϵ entries per period, so the draw must have T * m.n_ϵ elements; the
    # previous `MvNormal(T, 1.0)` was only correct for single-shock models.
    ϵ_draw ~ MvNormal(T * m.n_ϵ, 1.0)
    # Unstack into one n_ϵ-vector of shocks per period.
    ϵ = map(i -> ϵ_draw[((i-1)*m.n_ϵ+1):(i*m.n_ϵ)], 1:T)
    sol = generate_perturbation(m, p_d, p_f, Val(Order); cache)
    if !(sol.retcode == :Success)
        # Solver failure: assign zero posterior mass and bail out.
        @addlogprob! -Inf
        return
    end
    # Simulate the (possibly nonlinear) state space under the drawn shocks and
    # score the observables conditional on that shock path.
    problem = StateSpaceProblem(
        DifferentiableStateSpaceModels.dssm_evolution,
        DifferentiableStateSpaceModels.dssm_volatility,
        DifferentiableStateSpaceModels.dssm_observation,
        x0,
        (0, T),
        sol,
        noise = DefinedNoise(ϵ),
        obs_noise = sol.D,
        observables = z
    )
    # Fold the conditional log-likelihood into the joint density.
    @addlogprob! DifferenceEquations.solve(problem, NoiseConditionalFilter(); vectype = Zygote.Buffer).loglikelihood
end
rbc_joint (generic function with 2 methods)
## Estimation example: first-order, joint likelihood approach
c = SolverCache(model_rbc, Val(1), p_d) # first-order solver cache
turing_model = rbc_joint(fake_z, model_rbc, p_f, c, x0)
n_samples = 1000 # post-adaptation NUTS draws
n_adapts = 100 # NUTS warmup/adaptation iterations
δ = 0.65 # NUTS target acceptance rate
max_depth = 5 # A lower max_depth will lead to higher autocorrelation of samples, but faster. The time complexity is approximately 2^max_depth
chain = sample(turing_model, NUTS(n_adapts, δ; max_depth), n_samples; progress = true)
┌ Info: Found initial step size
│ ϵ = 0.05
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:06:13
Chains MCMC chain (1000×34×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 414.73 seconds Compute duration = 414.73 seconds parameters = α, β, ϵ_draw[1], ϵ_draw[2], ϵ_draw[3], ϵ_draw[4], ϵ_draw[5], ϵ_draw[6], ϵ_draw[7], ϵ_draw[8], ϵ_draw[9], ϵ_draw[10], ϵ_draw[11], ϵ_draw[12], ϵ_draw[13], ϵ_draw[14], ϵ_draw[15], ϵ_draw[16], ϵ_draw[17], ϵ_draw[18], ϵ_draw[19], ϵ_draw[20] internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4889 0.0288 0.0009 0.0036 35.2358 1.0522 ⋯ β 0.9520 0.0120 0.0004 0.0013 45.5091 1.0589 ⋯ ϵ_draw[1] -1.5515 0.3074 0.0097 0.0383 37.9697 1.0003 ⋯ ϵ_draw[2] 0.1473 0.2480 0.0078 0.0170 193.6988 0.9993 ⋯ ϵ_draw[3] -0.8528 0.2912 0.0092 0.0285 75.4364 1.0005 ⋯ ϵ_draw[4] -0.0699 0.2707 0.0086 0.0160 171.8753 1.0023 ⋯ ϵ_draw[5] -0.9131 0.2879 0.0091 0.0254 103.6113 1.0049 ⋯ ϵ_draw[6] 1.0865 0.2763 0.0087 0.0291 50.9118 1.0045 ⋯ ϵ_draw[7] -1.1715 0.3107 0.0098 0.0362 39.9536 1.0007 ⋯ ϵ_draw[8] 0.3844 0.2492 0.0079 0.0224 102.5929 1.0082 ⋯ ϵ_draw[9] -0.0495 0.2456 0.0078 0.0218 132.6706 0.9999 ⋯ ϵ_draw[10] 1.4457 0.3374 0.0107 0.0391 39.8546 0.9996 ⋯ ϵ_draw[11] -0.9216 0.3009 0.0095 0.0302 53.6178 0.9990 ⋯ ϵ_draw[12] -0.4079 0.2736 0.0087 0.0244 105.1944 0.9996 ⋯ ϵ_draw[13] 1.5898 0.3777 0.0119 0.0432 53.8391 1.0056 ⋯ ϵ_draw[14] 0.5034 0.2611 0.0083 0.0181 173.2732 0.9991 ⋯ ϵ_draw[15] -0.5716 0.2583 0.0082 0.0205 153.0957 1.0010 ⋯ ϵ_draw[16] -0.9199 0.3002 0.0095 0.0311 60.8713 0.9993 ⋯ ϵ_draw[17] 0.8431 0.2875 0.0091 0.0255 103.9810 0.9993 ⋯ ϵ_draw[18] 1.6663 0.3654 0.0116 0.0446 34.7493 1.0069 ⋯ ϵ_draw[19] -0.5262 0.2448 0.0077 0.0179 133.2332 1.0015 ⋯ ϵ_draw[20] -0.1090 0.9267 0.0293 0.1376 22.2084 1.0215 ⋯ 1 column omitted 
Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.4247 0.4696 0.4906 0.5105 0.5369 β 0.9280 0.9431 0.9521 0.9603 0.9759 ϵ_draw[1] -2.2052 -1.7433 -1.5388 -1.3281 -1.0378 ϵ_draw[2] -0.3436 -0.0118 0.1557 0.3113 0.6143 ϵ_draw[3] -1.4672 -1.0372 -0.8302 -0.6428 -0.3525 ϵ_draw[4] -0.5882 -0.2514 -0.0673 0.1071 0.4528 ϵ_draw[5] -1.5317 -1.0692 -0.8828 -0.7225 -0.4115 ϵ_draw[6] 0.6095 0.8964 1.0735 1.2460 1.6498 ϵ_draw[7] -1.8498 -1.3470 -1.1514 -0.9616 -0.5351 ϵ_draw[8] -0.1080 0.2124 0.3835 0.5525 0.8336 ϵ_draw[9] -0.5099 -0.2153 -0.0630 0.1175 0.4310 ϵ_draw[10] 0.8709 1.1985 1.4251 1.6511 2.1718 ϵ_draw[11] -1.5299 -1.1119 -0.9076 -0.7102 -0.3978 ϵ_draw[12] -1.0554 -0.5630 -0.3741 -0.2273 0.0522 ϵ_draw[13] 0.9363 1.3469 1.5276 1.7903 2.5023 ϵ_draw[14] 0.0264 0.3377 0.4887 0.6602 1.0833 ϵ_draw[15] -1.0990 -0.7295 -0.5661 -0.3971 -0.0987 ϵ_draw[16] -1.5599 -1.1067 -0.9102 -0.7059 -0.4036 ϵ_draw[17] 0.3351 0.6316 0.8193 1.0431 1.4251 ϵ_draw[18] 0.9603 1.4167 1.6561 1.8970 2.4037 ϵ_draw[19] -1.0457 -0.6945 -0.5230 -0.3569 -0.0578 ϵ_draw[20] -1.9507 -0.7969 -0.0547 0.4905 1.7471
## Estimation example: second-order, joint likelihood approach
c = SolverCache(model_rbc, Val(2), p_d) # second-order solver cache
turing_model = rbc_joint(fake_z_second, model_rbc, p_f, c, [x0; x0]) # stacked initial state for second order
n_samples = 1000 # post-adaptation NUTS draws
n_adapts = 100 # NUTS warmup/adaptation iterations
δ = 0.65 # NUTS target acceptance rate
max_depth = 5 # see note above: trades autocorrelation for per-step cost (~2^max_depth)
chain = sample(turing_model, NUTS(n_adapts, δ; max_depth), n_samples; progress = true)
┌ Info: Found initial step size
│ ϵ = 0.0234375
└ @ Turing.Inference C:\Users\wupei\.julia\packages\Turing\nfMhU\src\inference\hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:11:41
Chains MCMC chain (1000×34×1 Array{Float64, 3}): Iterations = 101:1:1100 Number of chains = 1 Samples per chain = 1000 Wall duration = 733.25 seconds Compute duration = 733.25 seconds parameters = α, β, ϵ_draw[1], ϵ_draw[2], ϵ_draw[3], ϵ_draw[4], ϵ_draw[5], ϵ_draw[6], ϵ_draw[7], ϵ_draw[8], ϵ_draw[9], ϵ_draw[10], ϵ_draw[11], ϵ_draw[12], ϵ_draw[13], ϵ_draw[14], ϵ_draw[15], ϵ_draw[16], ϵ_draw[17], ϵ_draw[18], ϵ_draw[19], ϵ_draw[20] internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size Summary Statistics parameters mean std naive_se mcse ess rhat e ⋯ Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯ α 0.4851 0.0277 0.0009 0.0033 47.7713 1.0034 ⋯ β 0.9516 0.0117 0.0004 0.0011 109.5834 0.9991 ⋯ ϵ_draw[1] -1.6259 0.3007 0.0095 0.0359 27.8116 1.0512 ⋯ ϵ_draw[2] 0.1767 0.2877 0.0091 0.0244 132.9864 1.0295 ⋯ ϵ_draw[3] -0.8984 0.2706 0.0086 0.0226 105.9210 1.0078 ⋯ ϵ_draw[4] -0.0772 0.2661 0.0084 0.0188 137.7001 1.0022 ⋯ ϵ_draw[5] -0.9567 0.2770 0.0088 0.0244 111.7968 0.9990 ⋯ ϵ_draw[6] 1.1438 0.2881 0.0091 0.0290 93.5977 0.9990 ⋯ ϵ_draw[7] -1.2502 0.3054 0.0097 0.0326 77.9409 0.9990 ⋯ ϵ_draw[8] 0.4245 0.2882 0.0091 0.0275 91.1914 1.0205 ⋯ ϵ_draw[9] -0.0481 0.2780 0.0088 0.0257 100.9479 1.0214 ⋯ ϵ_draw[10] 1.4637 0.3309 0.0105 0.0398 43.4521 1.0008 ⋯ ϵ_draw[11] -0.9303 0.2915 0.0092 0.0312 59.4857 0.9998 ⋯ ϵ_draw[12] -0.4190 0.2737 0.0087 0.0233 135.5965 1.0066 ⋯ ϵ_draw[13] 1.6176 0.3119 0.0099 0.0350 68.1616 1.0044 ⋯ ϵ_draw[14] 0.5670 0.2869 0.0091 0.0305 62.3127 1.0079 ⋯ ϵ_draw[15] -0.6368 0.3100 0.0098 0.0352 66.7596 0.9996 ⋯ ϵ_draw[16] -0.8959 0.2832 0.0090 0.0241 130.1619 1.0047 ⋯ ϵ_draw[17] 0.8097 0.2728 0.0086 0.0233 126.4728 1.0038 ⋯ ϵ_draw[18] 1.7941 0.3642 0.0115 0.0461 38.8073 1.0083 ⋯ ϵ_draw[19] -0.5802 0.2736 0.0087 0.0244 88.7291 1.0114 ⋯ ϵ_draw[20] 0.2928 0.7435 0.0235 0.1059 26.3271 1.0233 ⋯ 1 column omitted 
Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.4320 0.4665 0.4844 0.5043 0.5423 β 0.9277 0.9436 0.9514 0.9605 0.9731 ϵ_draw[1] -2.2675 -1.8121 -1.5990 -1.4211 -1.1193 ϵ_draw[2] -0.3717 -0.0008 0.1735 0.3645 0.7345 ϵ_draw[3] -1.4678 -1.0661 -0.8958 -0.7273 -0.3913 ϵ_draw[4] -0.6733 -0.2338 -0.0469 0.1023 0.3698 ϵ_draw[5] -1.5560 -1.1290 -0.9402 -0.7731 -0.4387 ϵ_draw[6] 0.6390 0.9395 1.1241 1.3286 1.7685 ϵ_draw[7] -1.8451 -1.4584 -1.2365 -1.0299 -0.6967 ϵ_draw[8] -0.1489 0.2373 0.4233 0.6038 1.0384 ϵ_draw[9] -0.6376 -0.2133 -0.0412 0.1174 0.4903 ϵ_draw[10] 0.9047 1.2237 1.4416 1.6610 2.1796 ϵ_draw[11] -1.5662 -1.1097 -0.9065 -0.7262 -0.4449 ϵ_draw[12] -0.9483 -0.6080 -0.4127 -0.2518 0.1579 ϵ_draw[13] 1.0905 1.3935 1.6003 1.8286 2.2429 ϵ_draw[14] 0.0687 0.3566 0.5469 0.7492 1.1956 ϵ_draw[15] -1.3596 -0.8152 -0.6134 -0.4276 -0.1027 ϵ_draw[16] -1.4853 -1.0803 -0.8894 -0.6808 -0.3993 ϵ_draw[17] 0.3064 0.6345 0.7919 0.9709 1.3976 ϵ_draw[18] 1.1474 1.5310 1.7752 2.0542 2.5415 ϵ_draw[19] -1.1441 -0.7550 -0.5730 -0.3945 -0.0722 ϵ_draw[20] -1.1762 -0.2066 0.2604 0.8241 1.7359