# Bayesian linear regression and message-passing demo with RxInfer.
#
# 1. Generate N noisy observations y = w'x + ϵ with features x = [1, z, z^2] and infer
#    the weights w by message passing.
# 2. Evaluate single message-update rules of the `+` and `*` nodes with `@call_rule`.
# 3. Infer p(x | y1, y2) in a two-observation Gaussian model and compare the result with
#    the closed-form product of Gaussians.

using Pkg; Pkg.activate("../."); Pkg.instantiate();
using IJulia
try
    IJulia.clear_output()
catch
    # Best effort only: ignore failures (presumably when not running in a Jupyter kernel).
end

using Plots, LinearAlgebra, LaTeXStrings

# Parameters
Σ = 1e5 * Diagonal(I, 3) # Covariance matrix of prior on w (large variance ⇒ weak prior)
σ2 = 2.0                 # Noise variance

# Generate data set
w = [1.0; 2.0; 0.25]     # True weights used to synthesize the data
N = 30
z = 10.0 * rand(N)
x_train = [[1.0; zi; zi^2] for zi in z] # Feature vector x = [1.0; z; z^2]
f(x) = dot(w, x)
y_train = map(f, x_train) + sqrt(σ2) * randn(N) # y[i] = w' * x[i] + ϵ
# Explicit `display`: outside a notebook cell, a bare expression is not echoed.
display(scatter(z, y_train, label="data", xlabel=L"z",
                ylabel=L"f([1.0, z, z^2]) + \epsilon"))

using RxInfer, Random

# Build model: Gaussian prior on w, Gaussian likelihood y[i] ~ N(w'x[i], σ2)
@model function linear_regression(y, x, N, Σ, σ2)
    w ~ MvNormalMeanCovariance(zeros(3), Σ)
    for i in 1:N
        y[i] ~ NormalMeanVariance(dot(w, x[i]), σ2)
    end
end

# Run message passing algorithm
results = infer(
    model = linear_regression(N=length(x_train), Σ=Σ, σ2=σ2),
    data = (y = y_train, x = x_train),
    returnvars = (w = KeepLast(),),
    iterations = 20,
);

# Plot result
# Keep the posterior under its own name: reassigning `w` would clobber the true weights
# that `f` reads.
w_posterior = results.posteriors[:w]
println("Posterior distribution of w: $(w_posterior)")
plt = scatter(z, y_train, label="data", xlabel=L"z", ylabel=L"f([1.0, z, z^2]) + \epsilon")
z_test = 0:0.2:12
x_test = [[1.0; zi; zi^2] for zi in z_test]
for _ in 1:10
    # Draw a weight vector from the posterior and overlay the curve it implies.
    w_sample = rand(w_posterior)
    # Mutate `plt` in place instead of reassigning the global inside the loop, which is
    # ambiguous under Julia's soft-scope rules when the file is run as a script.
    plot!(plt, z_test, map(x -> dot(w_sample, x), x_test), alpha=0.3, label="")
end
display(plt)

# Evaluate individual message-update rules directly on given input messages.
println("Forward message on Z:")
msg = @call_rule typeof(+)(:out, Marginalisation) (m_in1 = NormalMeanVariance(1.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))
display(msg)
println("Backward message on X:")
msg = @call_rule typeof(+)(:in1, Marginalisation) (m_out = NormalMeanVariance(3.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))
display(msg)
println("Forward message on Y:")
msg = @call_rule typeof(*)(:out, Marginalisation) (m_A = PointMass(4.0), m_in = NormalMeanVariance(1.0, 1.0))
display(msg)
println("Backward message on X:")
msg = @call_rule typeof(*)(:in, Marginalisation) (m_out = NormalMeanVariance(2.0, 1.0), m_A = PointMass(4.0))
display(msg)

# Data
y1_hat = 1.0
y2_hat = 2.0

# Construct the factor graph
@model function my_model(y1, y2)
    # `x` is the hidden state
    x ~ NormalMeanVariance(0.0, 4.0)
    # `y1` and `y2` are "clamped" observations
    y1 ~ NormalMeanVariance(x, 1.0)
    y2 ~ NormalMeanVariance(x, 2.0)
    return x
end

result = infer(model=my_model(), data=(y1=y1_hat, y2=y2_hat))
x_posterior = result.posteriors[:x]
println("Sum-product message passing result: p(x|y1,y2) = 𝒩($(mean(x_posterior)),$(var(x_posterior)))")

# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians
# (see lesson 4 for details): precisions (1/variance) add up, and the mean is the
# precision-weighted sum of the individual means scaled by the resulting variance.
v = 1 / (1/4 + 1/1 + 1/2)
m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)
println("Manual result: p(x|y1,y2) = 𝒩($(m), $(v))")

open("../../styles/aipstyle.html") do io
    display("text/html", read(io, String))
end