which is computationally much cheaper than the general case above.
An FFG is an undirected graph subject to the following construction rules (Forney, 2001)
A configuration is an assignment of values to all variables.
A configuration $\omega=(x_1,x_2,x_3,x_4,x_5)$ is said to be valid iff $f(\omega) \neq 0$
where $$ f_{=}(x_2,x_2^\prime,x_2^{\prime\prime}) \triangleq \delta(x_2-x_2^\prime)\, \delta(x_2-x_2^{\prime\prime}) $$
it follows that any inference problem on $f$ can be executed by a corresponding inference problem on $g$, e.g., $$\begin{align*} f(x_1 \mid x_2) &\triangleq \frac{\iint f(x_1,x_2,x_3,x_4) \,\mathrm{d}x_3 \mathrm{d}x_4 }{ \idotsint f(x_1,x_2,x_3,x_4) \,\mathrm{d}x_1 \mathrm{d}x_3 \mathrm{d}x_4} \\ &= \frac{\idotsint g(x_1,x_2,x_2^\prime,x_2^{\prime\prime},x_3,x_4) \,\mathrm{d}x_2^\prime \mathrm{d}x_2^{\prime\prime} \mathrm{d}x_3 \mathrm{d}x_4 }{ \idotsint g(x_1,x_2,x_2^\prime,x_2^{\prime\prime},x_3,x_4) \,\mathrm{d}x_1 \mathrm{d}x_2^\prime \mathrm{d}x_2^{\prime\prime} \mathrm{d}x_3 \mathrm{d}x_4} \\ &= g(x_1 \mid x_2) \end{align*}$$
$f_a(x_1,x_2,x_3) \cdot f_b(x_3,x_4,x_5) \cdot f_c(x_4)$ could represent the probabilistic model $$ p(x_1,x_2,x_3,x_4,x_5) = p(x_1,x_2|x_3) \cdot p(x_3,x_5|x_4) \cdot p(x_4) $$ where we identify $$\begin{align*} f_a(x_1,x_2,x_3) &= p(x_1,x_2|x_3) \\ f_b(x_3,x_4,x_5) &= p(x_3,x_5|x_4) \\ f_c(x_4) &= p(x_4) \end{align*}$$
for a model $f$ with given factorization $$ f(x_1,x_2,\ldots,x_7) = f_a(x_1) f_b(x_2) f_c(x_1,x_2,x_3) f_d(x_4) f_e(x_3,x_4,x_5) f_f(x_5,x_6,x_7) f_g(x_7) $$
which is computationally (much) lighter than executing the full sum $\sum_{x_1,\ldots,x_7}f(x_1,x_2,\ldots,x_7)$
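The saving can be checked numerically. The following is a minimal sketch in plain Julia (no ForneyLab), assuming every variable is discrete with `K` values and using randomly filled, purely hypothetical factor tables (`fa`, ..., `fg_`); the message names `msg_x3`, `msg_x5_fwd` and `msg_x5_bwd` are illustrative as well. It compares the brute-force sum against the result of pushing the sums inside the factorization.

```julia
using Random

Random.seed!(1)
K = 3   # assumption: every variable takes K discrete values

# Hypothetical factor tables with random non-negative entries
fa, fb, fd, fg_ = rand(K), rand(K), rand(K), rand(K)   # f_a(x1), f_b(x2), f_d(x4), f_g(x7)
fc = rand(K, K, K)   # f_c(x1, x2, x3)
fe = rand(K, K, K)   # f_e(x3, x4, x5)
ff = rand(K, K, K)   # f_f(x5, x6, x7)

# Brute force: one sum over all K^7 configurations
brute = sum(fa[x1] * fb[x2] * fc[x1, x2, x3] * fd[x4] * fe[x3, x4, x5] * ff[x5, x6, x7] * fg_[x7]
            for x1 in 1:K, x2 in 1:K, x3 in 1:K, x4 in 1:K, x5 in 1:K, x6 in 1:K, x7 in 1:K)

# Distributive law: three local sums (messages), each over K^3 terms
msg_x3     = [sum(fa[x1] * fb[x2] * fc[x1, x2, x3] for x1 in 1:K, x2 in 1:K) for x3 in 1:K]
msg_x5_fwd = [sum(msg_x3[x3] * fd[x4] * fe[x3, x4, x5] for x3 in 1:K, x4 in 1:K) for x5 in 1:K]
msg_x5_bwd = [sum(ff[x5, x6, x7] * fg_[x7] for x6 in 1:K, x7 in 1:K) for x5 in 1:K]

# Combine the two messages that meet on the edge x5
total = sum(msg_x5_fwd .* msg_x5_bwd)

println("brute force and message passing agree: ", isapprox(brute, total))
```

With `K = 3` the brute-force sum visits $3^7 = 2187$ configurations, whereas the three messages together require $3 \cdot 3^3 = 81$ products plus a final sum over $x_5$.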
since there are no enclosed variables.
The foregoing message update rules can be worked out in closed form and put into tables (e.g., see Tables 1 through 6 in Loeliger (2007)) for many standard factors such as essential probability distributions and operations such as additions, fixed-gain multiplications and branching (equality nodes).
In the optional slides below, we have worked out a few of these update rules, e.g., for the equality node, the addition node and the multiplication node.
If the update rules for all node types in a graph have been tabulated, then inference by message passing comes down to executing a set of table-lookup operations, thus creating a completely automatable Bayesian inference framework.
In our research lab BIASlab (FLUX 7.060), we are developing ForneyLab, which is a (Julia) toolbox for automating Bayesian inference by message passing in a factor graph.
or equivalently $$\begin{align*} p(w,\epsilon,D) &= \overbrace{p(w)}^{\text{weight prior}} \prod_{i=1}^N \overbrace{p(y_i\,|\,x_i,w,\epsilon_i)}^{\text{regression model}} \overbrace{p(\epsilon_i)}^{\text{noise model}} \\ &= \mathcal{N}(w\,|\,0,\Sigma) \prod_{i=1}^N \delta(y_i - w^T x_i - \epsilon_i) \mathcal{N}(\epsilon_i\,|\,0,\sigma^2) \end{align*}$$
using Pkg;Pkg.activate("probprog/workspace/");Pkg.instantiate();
IJulia.clear_output();
using PyPlot, ForneyLab, LinearAlgebra
# Parameters
Σ = 1e5 * Diagonal(I,3) # Covariance matrix of prior on w
σ2 = 2.0 # Noise variance
# Generate data set
w = [1.0; 2.0; 0.25]
N = 30
z = 10.0*rand(N)
x_train = [[1.0; z; z^2] for z in z] # Feature vector x = [1.0; z; z^2]
f(x) = (w'*x)[1]
y_train = map(f, x_train) + sqrt(σ2)*randn(N) # y[i] = w' * x[i] + ϵ
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon");
Now build the factor graph in ForneyLab, perform sum-product message passing and plot results (mean of posterior).
# Build factorgraph
fg = FactorGraph()
@RV w ~ GaussianMeanVariance(constant(zeros(3)), constant(Σ, id=:Σ), id=:w) # p(w)
for t=1:N
    x_t = Variable(id=:x_*t)
    d_t = Variable(id=:d_*t) # d=w'*x
    DotProduct(d_t, x_t, w) # p(d|w,x)
    @RV y_t ~ GaussianMeanVariance(d_t, constant(σ2, id=:σ2_*t), id=:y_*t) # p(y|d)
    placeholder(x_t, :x, index=t, dims=(3,))
    placeholder(y_t, :y, index=t);
end
# Build and run message passing algorithm
algo = messagePassingAlgorithm(w)
source_code = algorithmSourceCode(algo)
eval(Meta.parse(source_code))
data = Dict(:x => x_train, :y => y_train)
w_posterior_dist = step!(data)[:w]
# Plot result
println("Posterior distribution of w: $(w_posterior_dist)")
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon");
z_test = collect(0:0.2:12)
x_test = [[1.0; z; z^2] for z in z_test]
for sample=1:10
    w = ForneyLab.sample(w_posterior_dist)
    f_est(x) = (w'*x)[1]
    plot(z_test, map(f_est, x_test), "k-", alpha=0.3);
end
Posterior distribution of w: 𝒩(xi=[2.71e+02, 1.71e+03, 1.20e+04], w=[[15.00, 72.34, 4.42e+02][72.34, 4.42e+02, 3.00e+03][4.42e+02, 3.00e+03, 2.20e+04]])
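The printed posterior appears to be in precision-weighted (canonical) form, where `xi` is the precision-weighted mean and `w` is the precision matrix (not to be confused with the weight vector $w$). As a sanity check, the sketch below computes the same quantities in closed form for this model, without ForneyLab. It assumes the zero-mean prior $\mathcal{N}(w\,|\,0,\Sigma)$ and noise variance $\sigma^2$ from above, and it draws its own random data set, so only the structure of the result (not the exact numbers) is comparable with the output above. The names `W`, `xi`, `m` and `V` are illustrative.

```julia
using LinearAlgebra, Random

Random.seed!(42)
Σ  = 1e5 * Matrix(1.0I, 3, 3)   # prior covariance on w
σ2 = 2.0                        # noise variance
w_true = [1.0, 2.0, 0.25]
N = 30
z = 10.0 * rand(N)
x_train = [[1.0, zi, zi^2] for zi in z]                            # features [1, z, z^2]
y_train = [dot(w_true, x) for x in x_train] + sqrt(σ2) * randn(N)  # noisy targets

# Gaussian posterior in canonical form (the prior mean is zero, so it adds nothing to xi)
W  = inv(Σ) + sum(x * x' for x in x_train) / σ2              # posterior precision matrix
xi = sum(x * y for (x, y) in zip(x_train, y_train)) / σ2     # precision-weighted mean

m = W \ xi     # posterior mean
V = inv(W)     # posterior covariance

println("W[1,1] = ", W[1, 1])
println("posterior mean = ", m)
```

Since the first feature is always 1, the top-left entry of the precision matrix equals $N/\sigma^2 + 10^{-5} \approx 15$, which matches the `15.00` in the ForneyLab output.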
i.e., $\overrightarrow{\mu}_{Z}$ is the convolution of the messages $\overrightarrow{\mu}_{X}$ and $\overrightarrow{\mu}_{Y}$.
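For Gaussian messages the convolution is again Gaussian, with means and variances adding up. The sketch below checks this numerically in plain Julia with a simple Riemann sum, assuming the incoming messages $\overrightarrow{\mu}_{X} = \mathcal{N}(1,1)$ and $\overrightarrow{\mu}_{Y} = \mathcal{N}(2,1)$; the helper names `gauss` and `conv_at` and the grid settings are illustrative choices.

```julia
# Scalar Gaussian density with mean m and variance v
gauss(x, m, v) = exp(-(x - m)^2 / (2v)) / sqrt(2π * v)

dx   = 0.01
grid = -10.0:dx:16.0   # wide enough to cover essentially all mass of the integrand

# Forward message towards Z evaluated at z: ∫ μ_X(x) μ_Y(z - x) dx, as a Riemann sum
conv_at(z) = sum(gauss(x, 1.0, 1.0) * gauss(z - x, 2.0, 1.0) for x in grid) * dx

# Compare with the closed form N(z | 1 + 2, 1 + 1)
for z in 0.0:1.0:6.0
    println("z = $z: numeric = $(conv_at(z)), closed form = $(gauss(z, 3.0, 2.0))")
end
```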
since $$\begin{align*} \overrightarrow{\mu}_{Y}(y) &= \overrightarrow{\mu}_{X}(A^{-1}y) \\ &\propto \exp \left( -\frac{1}{2} \left( A^{-1}y - \overrightarrow{m}_{X}\right)^T \overrightarrow{V}_{X}^{-1} \left( A^{-1}y - \overrightarrow{m}_{X}\right)\right) \\ &= \exp \big( -\frac{1}{2} \left( y - A\overrightarrow{m}_{X}\right)^T \underbrace{A^{-T}\overrightarrow{V}_{X}^{-1} A^{-1}}_{(A \overrightarrow{V}_{X} A^T)^{-1}} \left( y - A\overrightarrow{m}_{X}\right)\big) \\ &\propto \mathcal{N}(y| A\overrightarrow{m}_{X},A\overrightarrow{V}_{X}A^T) \,. \end{align*}$$
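The identity under the brace and the resulting forward rule can be checked numerically. The sketch below uses plain Julia with a hypothetical invertible gain matrix `A` and hypothetical message parameters `m_x`, `V_x`; the scalar case at the end uses a gain of 4 and a unit-variance input message with mean 1.

```julia
using LinearAlgebra

A   = [2.0 1.0; 0.0 3.0]     # hypothetical invertible gain matrix
m_x = [1.0, -1.0]            # mean of the incoming message on X
V_x = [1.0 0.5; 0.5 2.0]     # covariance of the incoming message on X

# Forward message through the gain node: N(y | A m_x, A V_x A')
m_y = A * m_x
V_y = A * V_x * A'

# Identity used in the derivation: A^{-T} V_x^{-1} A^{-1} = (A V_x A')^{-1}
lhs = inv(A)' * inv(V_x) * inv(A)
println("identity holds: ", isapprox(lhs, inv(V_y)))

# Scalar case: gain a = 4 applied to N(1, 1)
a = 4.0
m_scalar, v_scalar = a * 1.0, a * 1.0 * a
println("scalar forward message: mean = $m_scalar, variance = $v_scalar")
```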
where $\overleftarrow{\xi}_X \triangleq \overleftarrow{W}_X \overleftarrow{m}_X$ and $\overleftarrow{W}_{X} \triangleq \overleftarrow{V}_{X}^{-1}$ (and similarly for $Y$).
Let's calculate the Gaussian forward and backward messages for the addition node in ForneyLab.
# Forward message towards Z
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(constant(1.0), constant(1.0), id=:x)
@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)
@RV z = x + y; z.id = :z
q = PosteriorFactorization(fg)
algo1 = messagePassingAlgorithm(z, id=:_forward_Z)
source_code1 = algorithmSourceCode(algo1)
eval(Meta.parse(source_code1))
msg_forward_Z = step_forward_Z!(Dict())[:z]
print("Forward message on Z: $(msg_forward_Z)")
Forward message on Z: 𝒩(m=3.00, v=2.00)
# Backward message towards X
fg = FactorGraph()
@RV x; x.id=:x
@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)
@RV z = x + y
GaussianMeanVariance(z, constant(3.0), constant(1.0), id=:z)
q = PosteriorFactorization(fg)
algo2 = messagePassingAlgorithm(x, id=:_backward_X)
source_code2 = algorithmSourceCode(algo2)
eval(Meta.parse(source_code2))
msg_backward_X = step_backward_X!(Dict())[:x]
print("Backward message on X: $(msg_backward_X)")
Backward message on X: 𝒩(m=1.00, v=2.00)
In the same way, we can also investigate the forward and backward messages for the matrix multiplication ("gain") node:
# Forward message towards Y
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(1.0, 1.0)
@RV y = 4.0 * x
q = PosteriorFactorization(fg)
#This is where the bugs live..
algo3 = messagePassingAlgorithm(y, id=:_y_fwd)
source_code3 = algorithmSourceCode(algo3)
eval(Meta.parse(source_code3))
msg_forward_Y = step_y_fwd!(Dict())[:y]
print("Forward message on Y: $(msg_forward_Y)")
Forward message on Y: 𝒩(m=4.00, v=16.00)
We'll use ForneyLab to build the above graph, and perform sum-product message passing to infer the posterior $p(x|y_1,y_2)$. We assume $p(y_1|x)$ and $p(y_2|x)$ to be Gaussian likelihoods with known variances: $$\begin{align*} p(y_1\,|\,x) &= \mathcal{N}(y_1\,|\,x, v_{y1}) \\ p(y_2\,|\,x) &= \mathcal{N}(y_2\,|\,x, v_{y2}) \end{align*}$$ Under this model, the posterior is given by: $$\begin{align*} p(x\,|\,y_1,y_2) &\propto \overbrace{p(y_1\,|\,x)\,p(y_2\,|\,x)}^{\text{likelihood}}\,\overbrace{p(x)}^{\text{prior}} \\ &=\mathcal{N}(x\,|\,\hat{y}_1, v_{y1})\, \mathcal{N}(x\,|\,\hat{y}_2, v_{y2}) \, \mathcal{N}(x\,|\,m_x, v_x) \end{align*}$$ so we can validate the answer by solving the Gaussian multiplication manually.
using ForneyLab
# Data
y1_hat = 1.0
y2_hat = 2.0
# Construct the factor graph
fg = FactorGraph()
@RV x ~ GaussianMeanVariance(constant(0.0), constant(4.0), id=:x) # Node p(x)
@RV y1 ~ GaussianMeanVariance(x, constant(1.0)) # Node p(y1|x)
@RV y2 ~ GaussianMeanVariance(x, constant(2.0)) # Node p(y2|x)
Clamp(y1, y1_hat) # Terminal (clamp) node for y1
Clamp(y2, y2_hat) # Terminal (clamp) node for y2
# draw(fg) # draw the constructed factor graph
# Perform sum-product message passing
algo4 = messagePassingAlgorithm(x, id=:_x)
source_code4 = algorithmSourceCode(algo4)
eval(Meta.parse(source_code4))
x_marginal = step_x!(Dict())[:x]
println("Sum-product message passing result: p(x|y1,y2) = 𝒩($(mean(x_marginal)),$(var(x_marginal)))")
# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians (see lesson 4 for details)
v = 1 / (1/4 + 1/1 + 1/2)
m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)
println("Manual result: p(x|y1,y2) = 𝒩($(m), $(v))")
Sum-product message passing result: p(x|y1,y2) = 𝒩(1.1428571428571428,0.5714285714285714)
Manual result: p(x|y1,y2) = 𝒩(1.1428571428571428, 0.5714285714285714)
# Backward message towards X
fg = FactorGraph()
x = Variable(id=:x)
@RV y = constant(4.0) * x
GaussianMeanVariance(y, constant(2.0), constant(1.0))
q = PosteriorFactorization(fg)
algo5 = messagePassingAlgorithm(x, id=:_x_fwd2)
source_code5 = algorithmSourceCode(algo5)
eval(Meta.parse(source_code5))
msg_backward_X = step_x_fwd2!(Dict())[:x]
print("Backward message on X: $(msg_backward_X)")
Backward message on X: 𝒩(xi=8.00, w=16.00)
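The canonical parameters printed above can be verified by hand. The sketch below is plain Julia and assumes the standard backward rule for a gain node $y = ax$ in precision-weighted form, $\overleftarrow{W}_X = a\,\overleftarrow{W}_Y\,a$ and $\overleftarrow{\xi}_X = a\,\overleftarrow{\xi}_Y$, with $\overleftarrow{\xi}_Y = \overleftarrow{W}_Y \overleftarrow{m}_Y$; the variable names are illustrative.

```julia
a = 4.0                  # gain of the multiplication node
m_y, v_y = 2.0, 1.0      # incoming (backward) Gaussian message on Y

w_y  = 1 / v_y           # precision of the message on Y
xi_y = w_y * m_y         # precision-weighted mean of the message on Y

# Backward message on X in canonical form
w_x  = a * w_y * a
xi_x = a * xi_y

println("Backward message on X: xi = $xi_x, w = $w_x")
println("Equivalent mean and variance: m = $(xi_x / w_x), v = $(1 / w_x)")
```

The canonical form is convenient here because it requires no inversion of the gain, so the same rule also applies when the gain matrix is not invertible.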