which is computationally much cheaper than the general case above.
An FFG is an undirected graph subject to the following construction rules (Forney, 2001):
A configuration is an assignment of values to all variables.
A configuration $\omega=(x_1,x_2,x_3,x_4,x_5)$ is said to be valid iff $f(\omega) \neq 0$
where $$ f_=(x_2,x_2^\prime,x_2^{\prime\prime}) \triangleq \delta(x_2-x_2^\prime)\, \delta(x_2-x_2^{\prime\prime}) $$
it follows that any inference problem on $f$ can be executed by a corresponding inference problem on $g$, e.g., $$\begin{align*} f(x_1 \mid x_2) &\triangleq \frac{\iint f(x_1,x_2,x_3,x_4) \,\mathrm{d}x_3 \mathrm{d}x_4 }{ \int\cdots\int f(x_1,x_2,x_3,x_4) \,\mathrm{d}x_1 \mathrm{d}x_3 \mathrm{d}x_4} \\ &= \frac{\int\cdots\int g(x_1,x_2,x_2^\prime,x_2^{\prime\prime},x_3,x_4) \,\mathrm{d}x_2^\prime \mathrm{d}x_2^{\prime\prime} \mathrm{d}x_3 \mathrm{d}x_4 }{ \int\cdots\int g(x_1,x_2,x_2^\prime,x_2^{\prime\prime},x_3,x_4) \,\mathrm{d}x_1 \mathrm{d}x_2^\prime \mathrm{d}x_2^{\prime\prime} \mathrm{d}x_3 \mathrm{d}x_4} \\ &= g(x_1 \mid x_2) \end{align*}$$
$f_a(x_1,x_2,x_3) \cdot f_b(x_3,x_4,x_5) \cdot f_c(x_4)$ could represent the probabilistic model $$ p(x_1,x_2,x_3,x_4,x_5) = p(x_1,x_2|x_3) \cdot p(x_3,x_5|x_4) \cdot p(x_4) $$ where we identify $$\begin{align*} f_a(x_1,x_2,x_3) &= p(x_1,x_2|x_3) \\ f_b(x_3,x_4,x_5) &= p(x_3,x_5|x_4) \\ f_c(x_4) &= p(x_4) \end{align*}$$
for a model $f$ with given factorization $$ f(x_1,x_2,\ldots,x_7) = f_a(x_1) f_b(x_2) f_c(x_1,x_2,x_3) f_d(x_4) f_e(x_3,x_4,x_5) f_f(x_5,x_6,x_7) f_g(x_7) $$
which is computationally (much) lighter than executing the full sum $\sum_{x_1,\ldots,x_7}f(x_1,x_2,\ldots,x_7)$
by the following product-of-sums: $$ (a + b)(c + d) \,.$$ Which of these two computations is cheaper to execute?
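To make the comparison concrete, here is a minimal check in plain Julia (with arbitrary example values, not taken from the lesson): both expressions evaluate to the same number, but the product-of-sums needs only 1 multiplication and 2 additions, whereas the sum-of-products needs 4 multiplications and 3 additions. This is the Distributive Law at work.
a, b, c, d = 1.0, 2.0, 3.0, 4.0

sum_of_products = a*c + a*d + b*c + b*d   # 4 multiplications, 3 additions
product_of_sums = (a + b)*(c + d)         # 1 multiplication, 2 additions

@assert sum_of_products == product_of_sums
println("Both forms evaluate to $(product_of_sums); the factored form is cheaper.")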
since there are no enclosed variables.
The foregoing message update rules can be worked out in closed form and put into tables; see, e.g., Tables 1 through 6 in Loeliger (2007) for many standard factors, such as essential probability distributions and operations such as addition, fixed-gain multiplication, and branching (equality nodes).
In the optional slides below, we have worked out a few more update rules for the addition node and the multiplication node.
If the update rules for all node types in a graph have been tabulated, then inference by message passing comes down to executing a set of table-lookup operations, thus creating a completely automatable Bayesian inference framework.
In our research lab BIASlab (FLUX 7.060), we are developing RxInfer, which is a (Julia) toolbox for automating Bayesian inference by message passing in a factor graph.
or equivalently $$\begin{align*} p(w,\epsilon,D) &= \overbrace{p(w)}^{\text{weight prior}} \prod_{i=1}^N \overbrace{p(y_i\,|\,x_i,w,\epsilon_i)}^{\text{regression model}} \overbrace{p(\epsilon_i)}^{\text{noise model}} \\ &= \mathcal{N}(w\,|\,0,\Sigma) \prod_{i=1}^N \delta(y_i - w^T x_i - \epsilon_i) \mathcal{N}(\epsilon_i\,|\,0,\sigma^2) \end{align*}$$
using PyPlot, LinearAlgebra
# Parameters
Σ = 1e5 * Diagonal(I,3) # Covariance matrix of prior on w
σ2 = 2.0 # Noise variance
# Generate data set
w = [1.0; 2.0; 0.25]
N = 30
z = 10.0*rand(N)
x_train = [[1.0; z; z^2] for z in z] # Feature vector x = [1.0; z; z^2]
f(x) = (w'*x)[1]
y_train = map(f, x_train) + sqrt(σ2)*randn(N) # y[i] = w' * x[i] + ϵ
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon");
Now build the factor graph in RxInfer, perform sum-product message passing and plot results (mean of posterior).
using RxInfer, Random
# Build model
@model function linear_regression(N, Σ, σ2)
    w ~ MvNormalMeanCovariance(constvar(zeros(3)), constvar(Σ))
    x = datavar(Vector{Float64}, N)
    y = datavar(Float64, N)
    for i in 1:N
        y[i] ~ NormalMeanVariance(dot(w, x[i]), σ2)
    end
    return w, x, y
end
# Run message passing algorithm
results = inference(
    model      = linear_regression(length(x_train), Σ, σ2),
    data       = (y = y_train, x = x_train),
    returnvars = (w = KeepLast(), ),
    iterations = 20
);
# Plot result
w = results.posteriors[:w]
println("Posterior distribution of w: $(w)")
scatter(z, y_train); xlabel(L"z"); ylabel(L"f([1.0, z, z^2]) + \epsilon");
z_test = collect(0:0.2:12)
x_test = [[1.0; z; z^2] for z in z_test]
for i in 1:10
    w_sample = rand(results.posteriors[:w])
    f_est(x) = (w_sample'*x)[1]
    plot(z_test, map(f_est, x_test), "k-", alpha=0.3);
end
Posterior distribution of w: MvNormalWeightedMeanPrecision( xi: [318.19138919704756, 2358.5409231141653, 19126.67282976907] Λ: [15.000010000030002 81.69926298019712 580.4592797749924; 81.69926298019712 580.4592897750223 4589.625579040317; 580.4592797749924 4589.625579040317 38295.96922430482] )
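The posterior above is reported in the weighted-mean/precision parameterization. If you prefer the familiar mean/covariance form, the distribution returned by RxInfer should support the standard `mean` and `cov` accessors (a small sketch; `w_post` is just a local name introduced here):
w_post = results.posteriors[:w]
println("posterior mean of w:       ", mean(w_post))
println("posterior covariance of w: ", cov(w_post))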
The great Michael Jordan (no, not the basketball player, but the machine learning researcher) wrote:
I basically know of two principles for treating complicated systems in simple ways: the first is the principle of modularity and the second is the principle of abstraction. I am an apologist for computational probability in machine learning because I believe that probability theory implements these two principles in deep and intriguing ways — namely through factorization and through averaging. Exploiting these two mechanisms as fully as possible seems to me to be the way forward in machine learning. — Michael Jordan, 1997 (quoted in Fre98).
Factor graphs realize these ideas nicely, both visually and computationally.
Visually, the modularity of conditional independencies in the model is displayed by the graph structure. Each node hides internal complexity, and by closing the box we can hierarchically move on to higher levels of abstraction.
Computationally, message passing-based inference uses the Distributive Law to avoid any unnecessary computations.
since $$\begin{align*} \overrightarrow{\mu}_{Y}(y) &= |A|^{-1}\overrightarrow{\mu}_{X}(A^{-1}y) \\ &\propto \exp \left( -\frac{1}{2} \left( A^{-1}y - \overrightarrow{m}_{X}\right)^T \overrightarrow{V}_{X}^{-1} \left( A^{-1}y - \overrightarrow{m}_{X}\right)\right) \\ &= \exp \big( -\frac{1}{2} \left( y - A\overrightarrow{m}_{X}\right)^T \underbrace{A^{-T}\overrightarrow{V}_{X}^{-1} A^{-1}}_{(A \overrightarrow{V}_{X} A^T)^{-1}} \left( y - A\overrightarrow{m}_{X}\right)\big) \\ &\propto \mathcal{N}(y| A\overrightarrow{m}_{X},A\overrightarrow{V}_{X}A^T) \,. \end{align*}$$
where $\overleftarrow{\xi}_X \triangleq \overleftarrow{W}_X \overleftarrow{m}_X$ and $\overleftarrow{W}_{X} \triangleq \overleftarrow{V}_{X}^{-1}$ (and similarly for $Y$).
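As a quick numerical sanity check of the forward gain-node rule derived above (a plain-Julia sketch with made-up numbers, not part of the RxInfer demos below): drawing samples of $X \sim \mathcal{N}(\overrightarrow{m}_{X}, \overrightarrow{V}_{X})$ and pushing them through $Y = AX$ should reproduce the moments $A\overrightarrow{m}_{X}$ and $A\overrightarrow{V}_{X}A^T$.
using LinearAlgebra, Random

Random.seed!(1)
A   = [1.0 0.5; 0.0 2.0]     # gain matrix (arbitrary example values)
m_X = [1.0, -1.0]            # mean of the incoming message on X
V_X = [2.0 0.3; 0.3 1.0]     # covariance of the incoming message on X

# Draw samples X ~ N(m_X, V_X) and map them through Y = A*X
L = cholesky(V_X).L
Y = [A * (m_X + L * randn(2)) for _ in 1:100_000]

m_Y_mc = sum(Y) / length(Y)
V_Y_mc = sum((y - m_Y_mc) * (y - m_Y_mc)' for y in Y) / length(Y)

println("Monte-Carlo mean of Y ≈ $(m_Y_mc)  vs  A*m_X    = $(A * m_X)")
println("Monte-Carlo cov of Y  ≈ $(V_Y_mc)  vs  A*V_X*A' = $(A * V_X * A')")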
Let's calculate the Gaussian forward and backward messages for the addition node in RxInfer.
println("Forward message on Z:")
@call_rule typeof(+)(:out, Marginalisation) (m_in1 = NormalMeanVariance(1.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))
Forward message on Z:
NormalMeanVariance{Float64}(μ=3.0, v=2.0)
println("Backward message on X:")
@call_rule typeof(+)(:in1, Marginalisation) (m_out = NormalMeanVariance(3.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))
Backward message on X:
NormalMeanVariance{Float64}(μ=1.0, v=2.0)
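These numbers are easy to verify by hand: for $Z = X + Y$ with Gaussian messages, the forward rule adds the means and the variances, and the backward rule subtracts the known mean while the variances still add. A minimal check in plain Julia (the variable names are just local labels for the messages above):
μ_x, v_x = 1.0, 1.0    # message on X (in1)
μ_y, v_y = 2.0, 1.0    # message on Y (in2)
μ_z, v_z = 3.0, 1.0    # message on Z (out), used for the backward rule

println("forward  on Z: μ = $(μ_x + μ_y), v = $(v_x + v_y)")   # expect 𝒩(3.0, 2.0)
println("backward on X: μ = $(μ_z - μ_y), v = $(v_z + v_y)")   # expect 𝒩(1.0, 2.0)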
In the same way, we can also investigate the forward and backward messages for the matrix multiplication ("gain") node.
println("Forward message on Y:")
@call_rule typeof(*)(:out, Marginalisation) (m_A = PointMass(4.0), m_in = NormalMeanVariance(1.0, 1.0))
Forward message on Y:
NormalMeanVariance{Float64}(μ=4.0, v=16.0)
println("Backward message on X:")
@call_rule typeof(*)(:in, Marginalisation) (m_out = NormalMeanVariance(2.0, 1.0), m_A = PointMass(4.0), meta = TinyCorrection())
Backward message on X:
NormalWeightedMeanPrecision{Float64}(xi=8.0, w=16.0)
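Again, these results match the closed-form gain rules: for $Y = AX$ with scalar $A = 4$, the forward message has mean $A\mu_X$ and variance $A v_X A^T$, and the backward message in weighted-mean/precision form has $\xi_X = A^T W_Y \mu_Y$ and $W_X = A^T W_Y A$. A minimal check in plain Julia (values copied from the calls above):
A = 4.0

# forward on Y, given incoming message 𝒩(1.0, 1.0) on X
μ_X, v_X = 1.0, 1.0
println("forward  on Y: μ = $(A * μ_X), v = $(A * v_X * A)")        # expect 𝒩(4.0, 16.0)

# backward on X, given incoming message 𝒩(2.0, 1.0) on Y
μ_Y, v_Y = 2.0, 1.0
W_Y = 1 / v_Y
println("backward on X: ξ = $(A * W_Y * μ_Y), W = $(A * W_Y * A)")  # expect ξ = 8.0, W = 16.0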
We'll use RxInfer to build the above graph, and perform sum-product message passing to infer the posterior $p(x|y_1,y_2)$. We assume $p(y_1|x)$ and $p(y_2|x)$ to be Gaussian likelihoods with known variances: $$\begin{align*} p(y_1\,|\,x) &= \mathcal{N}(y_1\,|\,x, v_{y1}) \\ p(y_2\,|\,x) &= \mathcal{N}(y_2\,|\,x, v_{y2}) \end{align*}$$ Under this model, the posterior is given by: $$\begin{align*} p(x\,|\,y_1,y_2) &\propto \overbrace{p(y_1\,|\,x)\,p(y_2\,|\,x)}^{\text{likelihood}}\,\overbrace{p(x)}^{\text{prior}} \\ &=\mathcal{N}(x\,|\,\hat{y}_1, v_{y1})\, \mathcal{N}(x\,|\,\hat{y}_2, v_{y2}) \, \mathcal{N}(x\,|\,m_x, v_x) \end{align*}$$ so we can validate the answer by solving the Gaussian multiplication manually.
# Data
y1_hat = 1.0
y2_hat = 2.0
# Construct the factor graph
@model function my_model()
    # `x` is the hidden state
    x = randomvar()

    # `y1` and `y2` are "clamped" observations
    y1 = datavar(Float64)
    y2 = datavar(Float64)

    x  ~ NormalMeanVariance(constvar(0.0), constvar(4.0))
    y1 ~ NormalMeanVariance(x, constvar(1.0))
    y2 ~ NormalMeanVariance(x, constvar(2.0))

    return x
end
result = inference(model = my_model(), data = (y1 = y1_hat, y2 = y2_hat))
println("Sum-product message passing result: p(x|y1,y2) = 𝒩($(mean(result.posteriors[:x])),$(var(result.posteriors[:x])))")
# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians (see lesson 4 for details)
v = 1 / (1/4 + 1/1 + 1/2)
m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)
println("Manual result: p(x|y1,y2) = 𝒩($(m), $(v))")
Sum-product message passing result: p(x|y1,y2) = 𝒩(1.1428571428571428,0.5714285714285714)
Manual result: p(x|y1,y2) = 𝒩(1.1428571428571428, 0.5714285714285714)
open("../../styles/aipstyle.html") do f display("text/html", read(f, String)) end