# Factor-graph exercises with ForneyLab:
#   1. Bayesian linear regression with polynomial features x = [1, z, z^2],
#   2. single messages through addition and scalar-gain nodes,
#   3. a sum-product example checked against the closed-form Gaussian product.
using Pkg
Pkg.activate("probprog/workspace/")
Pkg.instantiate()
# `IJulia` only exists inside a Jupyter kernel; skip clearing output elsewhere.
isdefined(Main, :IJulia) && Main.IJulia.clear_output()
using PyPlot, ForneyLab, LinearAlgebra

# ---------------------------------------------------------------------------
# Bayesian linear regression
# ---------------------------------------------------------------------------

# Parameters
Σ = 1e5 * Diagonal(I, 3)   # Covariance matrix of prior on w
σ2 = 2.0                   # Noise variance

# Generate data set
w = [1.0; 2.0; 0.25]       # True weights; the name `w` is rebound by `@RV w` below
N = 30
z = 10.0 * rand(N)
x_train = [[1.0; zi; zi^2] for zi in z]   # Feature vector x = [1.0; z; z^2]
# Reads the global `w`, so it is only meaningful before `w` is rebound below.
f(x) = dot(w, x)
y_train = map(f, x_train) + sqrt(σ2) * randn(N)   # y[i] = w' * x[i] + ϵ
figure()
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon")

# Build factor graph
fg = FactorGraph()
@RV w ~ GaussianMeanVariance(constant(zeros(3)), constant(Σ, id=:Σ), id=:w)   # p(w)
for t in 1:N
    x_t = Variable(id=:x_*t)
    d_t = Variable(id=:d_*t)   # d = w'*x
    DotProduct(d_t, x_t, w)    # p(d|w,x)
    @RV y_t ~ GaussianMeanVariance(d_t, constant(σ2, id=:σ2_*t), id=:y_*t)   # p(y|d)
    placeholder(x_t, :x, index=t, dims=(3,))
    placeholder(y_t, :y, index=t)
end

# Build and run message passing algorithm.
# Bind a posterior factorization to this graph, as every section below does, so the
# algorithm is not derived from a factorization created for some other graph
# (ForneyLab tracks a global "current" factorization).
q = PosteriorFactorization(fg)
algo = messagePassingAlgorithm(w)
source_code = algorithmSourceCode(algo)
eval(Meta.parse(source_code))   # defines `step!`
data = Dict(:x => x_train, :y => y_train)
w_posterior_dist = step!(data)[:w]

# Plot result
println("Posterior distribution of w: $(w_posterior_dist)")
figure()
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon")
z_test = collect(0:0.2:12)
x_test = [[1.0; zi; zi^2] for zi in z_test]
# Overlay regression curves for 10 posterior samples of w. A loop-local name keeps the
# loop from rebinding the global `w` (and avoids Julia's soft-scope ambiguity warning).
for _ in 1:10
    w_sample = ForneyLab.sample(w_posterior_dist)
    plot(z_test, [dot(w_sample, xt) for xt in x_test], "k-", alpha=0.3)
end

# ---------------------------------------------------------------------------
# Forward message towards Z  (z = x + y)
# ---------------------------------------------------------------------------
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(constant(1.0), constant(1.0), id=:x)
@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)
@RV z = x + y; z.id = :z
q = PosteriorFactorization(fg)
algo1 = messagePassingAlgorithm(z, id=:_forward_Z)
source_code1 = algorithmSourceCode(algo1)
eval(Meta.parse(source_code1))   # defines `step_forward_Z!`
# z is attached only to the addition node, so its marginal is the forward message.
msg_forward_Z = step_forward_Z!(Dict())[:z]
println("Forward message on Z: $(msg_forward_Z)")

# ---------------------------------------------------------------------------
# Backward message towards X  (z = x + y, with a factor N(z | 3, 1))
# ---------------------------------------------------------------------------
fg = FactorGraph()
@RV x; x.id = :x
@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)
@RV z = x + y
GaussianMeanVariance(z, constant(3.0), constant(1.0), id=:z)
q = PosteriorFactorization(fg)
algo2 = messagePassingAlgorithm(x, id=:_backward_X)
source_code2 = algorithmSourceCode(algo2)
eval(Meta.parse(source_code2))   # defines `step_backward_X!`
# x has no prior, so its marginal equals the backward message from the addition node.
msg_backward_X = step_backward_X!(Dict())[:x]
println("Backward message on X: $(msg_backward_X)")

# ---------------------------------------------------------------------------
# Forward message towards Y  (y = 4 x)
# ---------------------------------------------------------------------------
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(1.0, 1.0)
@RV y = 4.0 * x
q = PosteriorFactorization(fg)
# NOTE(review): the original author marked this section "This is where the bugs live..".
# Unlike the other graphs, the prior takes raw numbers instead of `constant(...)` and the
# gain is a plain Float64 rather than `constant(4.0)`. TODO confirm which part misbehaves.
algo3 = messagePassingAlgorithm(y, id=:_y_fwd)
source_code3 = algorithmSourceCode(algo3)
eval(Meta.parse(source_code3))   # defines `step_y_fwd!`
msg_forward_Y = step_y_fwd!(Dict())[:y]
println("Forward message on Y: $(msg_forward_Y)")

# ---------------------------------------------------------------------------
# Sum-product example: p(x | y1, y2) with two noisy observations of x
# ---------------------------------------------------------------------------
using ForneyLab

# Data
y1_hat = 1.0
y2_hat = 2.0

# Construct the factor graph
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(constant(0.0), constant(4.0), id=:x)   # Node p(x)
@RV y1 ~ GaussianMeanVariance(x, constant(1.0))                    # Node p(y1|x)
@RV y2 ~ GaussianMeanVariance(x, constant(2.0))                    # Node p(y2|x)
Clamp(y1, y1_hat)   # Terminal (clamp) node for y1
Clamp(y2, y2_hat)   # Terminal (clamp) node for y2
# draw(fg) # draw the constructed factor graph

# Perform sum-product message passing. Every other section binds a factorization to
# its own graph; without this call the one created for the previous graph is reused.
q = PosteriorFactorization(fg)
algo4 = messagePassingAlgorithm(x, id=:_x)
source_code4 = algorithmSourceCode(algo4)
eval(Meta.parse(source_code4))   # defines `step_x!`
x_marginal = step_x!(Dict())[:x]
println("Sum-product message passing result: p(x|y1,y2) = 𝒩($(mean(x_marginal)),$(var(x_marginal)))")

# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians
# (see lesson 4 for details): precisions add, and the mean is the precision-weighted sum.
v = 1 / (1/4 + 1/1 + 1/2)
m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)
println("Manual result: p(x|y1,y2) = 𝒩($(m), $(v))")

# ---------------------------------------------------------------------------
# Backward message towards X  (y = 4 x, with a factor N(y | 2, 1))
# ---------------------------------------------------------------------------
fg = FactorGraph()
x = Variable(id=:x)
@RV y = constant(4.0) * x
GaussianMeanVariance(y, constant(2.0), constant(1.0))
q = PosteriorFactorization(fg)
# The algorithm id says "fwd2", but the target is x. x has no prior, so its marginal is
# the backward message through the gain node. The id is kept so `step_x_fwd2!` stays defined.
algo5 = messagePassingAlgorithm(x, id=:_x_fwd2)
source_code5 = algorithmSourceCode(algo5)
eval(Meta.parse(source_code5))   # defines `step_x_fwd2!`
msg_backward_X = step_x_fwd2!(Dict())[:x]
println("Backward message on X: $(msg_backward_X)")

# Inject the course stylesheet. An HTML display only exists inside a notebook front end.
if isdefined(Main, :IJulia)
    open("../../styles/aipstyle.html") do io
        display("text/html", read(io, String))
    end
end