using Pkg; Pkg.activate("../../."); Pkg.instantiate();
using IJulia; try IJulia.clear_output(); catch _ end

using JLD
using Statistics
using LinearAlgebra
using Distributions
using RxInfer
using ColorSchemes
using LaTeXStrings
using Plots

default(label="", grid=false, linewidth=3, margin=10Plots.pt)

"""
    marginal_moments(marginals)

Return `(means, variances)` for a vector of multivariate marginals. Each is stacked
column-wise into a matrix with one column per element of `marginals` (one per time step).
"""
function marginal_moments(marginals)
    # `reduce(hcat, ...)` instead of `cat(xs...; dims=2)`: avoids splatting a
    # length-T collection into varargs.
    return reduce(hcat, mean.(marginals)), reduce(hcat, var.(marginals))
end

# Load data from file
data = load("../datasets/shaking_buildings.jld")

# Data
states = data["states"]
observations = data["observations"]

# Parameters
mass = data["m"]
friction = data["c"]
stiffness = data["k"]

# Measurement noise parameter.
# NOTE(review): this is labelled a *variance*, and `LGDS` below passes it to
# `NormalMeanVariance` unchanged, yet `LGDS_Q` passes `σ^2`. Only one of the two
# conventions can match how the data was generated. Confirm against the
# data-generating code and make both models agree.
σ = data["σ"]

# Time
Δt = data["Δt"]
T = length(observations)
time = range(1, step=Δt, length=T)

# Observations as a plain vector, shared by both inference runs below
y_obs = [observations[k] for k in 1:T]

plot(time, states[1,:], color="red", label="states",
     xlabel="time (sec)", ylabel="train position")
scatter!(time, observations, color="black", label="observations",
         legend=:topleft, size=(800,300))

# Transition matrix. Its structure (row 1: z1 + Δt*z2) matches a forward-Euler
# step of a mass-spring-damper system with state [position, velocity].
A = [1 Δt; -stiffness/mass*Δt -friction/mass*Δt+1]

# Emission matrix: only the first state component is observed
C = [1.0, 0.0]

# Set process noise covariance matrix
Q = diagm(ones(2))

@model function LGDS(y, A, C, Q, σ, T)
    "State estimation in linear Gaussian dynamical system"

    # Prior state
    z_0 ~ MvNormalMeanCovariance(zeros(2), diageye(2))

    z_kmin1 = z_0
    for k in 1:T
        # State transition
        z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)

        # Likelihood: σ is used directly as the variance (see the NOTE on σ above)
        y[k] ~ NormalMeanVariance(dot(C, z[k]), σ)

        # Update recursive aux
        z_kmin1 = z[k]
    end
end

results = infer(
    model = LGDS(A=A, C=C, Q=Q, σ=σ, T=T),
    data = (y = y_obs,),
    free_energy = true,
)

m_z, v_z = marginal_moments(results.posteriors[:z])

# Ribbon is ±1 standard deviation. `v_z` holds variances, which are in squared
# position units and must not be drawn on the position axis directly.
plot(time, states[1,:], color="red", label="states",
     xlabel="time (sec)", ylabel="train position")
plot!(time, m_z[1,:], color="blue", ribbon=sqrt.(v_z[1,:]), label="inferred")
scatter!(time, observations, color="black", alpha=0.2, label="observations",
         legend=:bottomright, size=(800,300))

@model function LGDS_Q(y, A, C, σ, T)
    "State estimation in a linear Gaussian dynamical system with unknown process noise"

    # Prior state
    z_0 ~ MvNormalMeanCovariance(zeros(2), diageye(2))

    # Process noise covariance matrix
    Q ~ InverseWishart(10, diageye(2))

    z_kmin1 = z_0
    for k in 1:T
        # State transition
        z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)

        # Likelihood: σ is squared here, unlike in `LGDS` (see the NOTE on σ above)
        y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)

        # Update recursive aux
        z_kmin1 = z[k]
    end
end

# Iterations of variational inference
num_iters = 100

# Initialize variational marginal distributions
init = @initialization begin
    q(z) = MvNormalMeanCovariance(zeros(2), diageye(2))
    q(Q) = InverseWishart(10, diageye(2))
end

# Define variational distribution factorization: states jointly, Q separately
constraints = @constraints begin
    q(z_0, z, Q) = q(z_0, z)q(Q)
end

# Variational inference procedure
results = infer(
    model = LGDS_Q(A=A, C=C, σ=σ, T=T),
    data = (y = y_obs,),
    constraints = constraints,
    iterations = num_iters,
    options = (limit_stack_depth = 100,),
    initialization = init,
    free_energy = true,
    showprogress = true,
)

plot(1:num_iters, results.free_energy, color="black", xscale=:log10,
     xlabel="Number of iterations", ylabel="Free Energy", size=(800,300))

# Marginals from the final iteration
m_z, v_z = marginal_moments(last(results.posteriors[:z]))

# Ribbon is ±1 standard deviation (square root of the stored variances)
plot(time, states[1,:], color="red", label="states",
     xlabel="time (sec)", ylabel="train position")
plot!(time, m_z[1,:], color="blue", ribbon=sqrt.(v_z[1,:]), label="inferred")
scatter!(time, observations, color="black", alpha=0.2, label="observations",
         legend=:topleft, size=(800,300))

# Point estimate of Q from the final iteration. Despite the name, this is the
# posterior *mean* of the InverseWishart marginal, not its mode (MAP); the name is
# kept so downstream references still work.
Q_MAP = mean(last(results.posteriors[:Q]))

# True data
Q_true = data["Q"]

# Colorbar limits shared by both heatmaps so their colours are comparable
clims = extrema([vec(Q_MAP); vec(Q_true)])

# Plot covariance matrices as heatmaps
p401 = heatmap(Q_MAP, axis=([], false), yflip=true, title="Estimated", clims=clims)
p402 = heatmap(Q_true, axis=([], false), yflip=true, title="True", clims=clims)
plot(p401, p402, layout=(1,2), size=(900,300))